Unverified Commit 20bd2271 authored by fzyzcjy, committed by GitHub

Support true on-policy (#12058)

parent 64994980
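Note on the change: the new --rl-on-policy-target fsdp flag (added to ServerArgs below) switches a set of kernels and dtype choices so that the inference-side numerics line up with an FSDP trainer, and the logprobs reported during rollout should match the ones the trainer recomputes for the same tokens. A minimal sketch of the property this targets, using hypothetical logprob tensors; this illustrates the goal and is not code from this commit:

import torch

# Hypothetical per-token logprobs of the sampled tokens: one set from the rollout
# engine, one recomputed by the FSDP trainer over the same sequence.
rollout_logprobs = torch.tensor([-1.25, -0.50, -2.00])
trainer_logprobs = torch.tensor([-1.25, -0.50, -2.00])

# "True on-policy" aims for exact agreement, so the importance ratio
# exp(trainer - rollout) is exactly 1 and no off-policy correction is needed.
ratio = torch.exp(trainer_logprobs - rollout_logprobs)
assert torch.equal(ratio, torch.ones_like(ratio))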
@@ -29,6 +29,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
@@ -59,6 +60,11 @@ logger = logging.getLogger(__name__)
class SiluAndMul(CustomOp):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if get_global_server_args().rl_on_policy_target == "fsdp":
self._forward_method = self.forward_native
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
......
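Forcing forward_native here routes the activation through plain PyTorch ops instead of the fused kernel, so the gated SiLU is computed with the same operations a pure-PyTorch trainer would use. A minimal sketch of the identity being relied on (not the SGLang kernel itself), assuming gate and up are concatenated along the last dimension as in the merged gate_up projection:

import torch
import torch.nn.functional as F

def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    # x holds [gate, up] concatenated on the last dim.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

gate, up = torch.randn(4, 128), torch.randn(4, 128)
merged = torch.cat([gate, up], dim=-1)

# Equivalent to the separate gate_proj / up_proj formulation: silu(gate) * up.
assert torch.allclose(silu_and_mul_native(merged), F.silu(gate) * up)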
@@ -73,9 +73,16 @@ class RMSNorm(CustomOp):
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
cast_x_before_out_mul: bool = False,
fp32_residual: bool = False,
weight_dtype: Optional[torch.dtype] = None,
override_orig_dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.cast_x_before_out_mul = cast_x_before_out_mul
self.fp32_residual = fp32_residual
self.override_orig_dtype = override_orig_dtype
self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
self.variance_epsilon = eps
self.hidden_size = hidden_size
self.variance_size_override = (
@@ -165,11 +172,14 @@ class RMSNorm(CustomOp):
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
orig_dtype = x.dtype
orig_dtype = self.override_orig_dtype or x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
if self.fp32_residual:
residual = x.clone()
else:
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
@@ -191,7 +201,12 @@ class RMSNorm(CustomOp):
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = (x * self.weight).to(orig_dtype)
if self.cast_x_before_out_mul:
x = self.weight * x.to(orig_dtype)
else:
x = (x * self.weight).to(orig_dtype)
if residual is None:
return x
else:
......
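The new RMSNorm knobs (weight_dtype, cast_x_before_out_mul, fp32_residual, override_orig_dtype) let the native path reproduce a training-style fp32 norm, in particular the ordering where the normalized activation is cast back before being multiplied by the weight. A small sketch of the two orderings, assuming a bf16 input and an fp32 norm weight; they agree mathematically but not necessarily bitwise, which is exactly what these flags are meant to control:

import torch

def rmsnorm(x, weight, eps=1e-6, cast_x_before_out_mul=False):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    if cast_x_before_out_mul:
        # Training-style ordering: cast the normalized activation back first,
        # then multiply by the (here fp32) weight.
        return weight * x.to(orig_dtype)
    # Default ordering: multiply in fp32, then cast the product back.
    return (x * weight).to(orig_dtype)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.randn(8, dtype=torch.float32)
a = rmsnorm(x, w, cast_x_before_out_mul=True)   # fp32 output (bf16 * fp32 promotes)
b = rmsnorm(x, w, cast_x_before_out_mul=False)  # bf16 output
# Same math up to rounding, but generally not bitwise identical.
print((a.float() - b.float()).abs().max())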
@@ -593,6 +593,11 @@ class LogitsProcessor(nn.Module):
None, # bias
True, # is_vnni
)
elif get_global_server_args().rl_on_policy_target == "fsdp":
# Due to weight tying, we may not be able to change lm_head's weight dtype
logits = torch.matmul(
hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
)
else:
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
......
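Context for the bf16 matmul above: when the embedding is kept in fp32 for this mode (see the Qwen2 changes further down) and lm_head shares that tensor through weight tying, its dtype cannot be changed independently, so the cast happens at matmul time instead. A small illustration of the tied-weight constraint, with made-up sizes:

import torch
import torch.nn as nn

vocab, hidden = 1000, 64
embed = nn.Embedding(vocab, hidden, dtype=torch.float32)
lm_head = nn.Linear(hidden, vocab, bias=False)
lm_head.weight = embed.weight  # weight tying: same Parameter object

# Casting lm_head to bf16 would silently change the embedding too, so the
# logits computation casts activations and weight at use time instead.
hidden_states = torch.randn(2, hidden)
logits = torch.matmul(hidden_states.bfloat16(), lm_head.weight.T.bfloat16())
assert embed.weight.dtype == torch.float32  # the shared parameter is untouched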
@@ -11,6 +11,7 @@ import triton
import triton.language as tl
from sglang.srt.custom_op import CustomOp
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
@@ -124,18 +125,29 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
if get_global_server_args().rl_on_policy_target == "fsdp":
self._forward_method = self.forward_native
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
init_device = (
"cpu" if get_global_server_args().rl_on_policy_target == "fsdp" else None
)
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device=init_device
)
/ self.rotary_dim
)
)
if get_global_server_args().rl_on_policy_target == "fsdp":
inv_freq = inv_freq.cuda()
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
......
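The init_device = "cpu" change follows the NOTE in the surrounding code: HF builds the rotary cache on CPU, and computing base ** exponent on GPU can differ in the last bits, which then propagates through cos/sin into every layer. A minimal sketch of the comparison (needs a CUDA device; the exact deltas are hardware-dependent):

import torch

rotary_dim, base = 128, 10000.0

def compute_inv_freq(device=None):
    exponent = torch.arange(0, rotary_dim, 2, dtype=torch.float, device=device) / rotary_dim
    return 1.0 / (base ** exponent)

if torch.cuda.is_available():
    cpu_then_cuda = compute_inv_freq("cpu").cuda()
    cuda_direct = compute_inv_freq("cuda")
    # Usually agrees closely, but not always bitwise, which is enough to break
    # exact matching against a trainer whose rotary cache was built on CPU.
    print((cpu_then_cuda - cuda_direct).abs().max())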
@@ -102,6 +102,14 @@ class Sampler(nn.Module):
if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
probs_without_temp_scaling = torch.softmax(logits, dim=-1)
if get_global_server_args().rl_on_policy_target == "fsdp":
logits_div_temperature = (
logits.bfloat16().div(sampling_info.temperatures).bfloat16()
)
logprobs_via_logsoftmax_kernel = torch.log_softmax(
logits_div_temperature, dim=-1
)
# Post process logits
logits.div_(sampling_info.temperatures)
logits[:] = torch.softmax(logits, dim=-1)
@@ -148,8 +156,11 @@
)
if return_logprob:
if get_global_server_args().rl_on_policy_target == "fsdp":
logprobs = logprobs_via_logsoftmax_kernel
del logprobs_via_logsoftmax_kernel
# clamp to avoid -inf
if SGLANG_RETURN_ORIGINAL_LOGPROB:
elif SGLANG_RETURN_ORIGINAL_LOGPROB:
logprobs = torch.log(probs_without_temp_scaling).clamp(
min=torch.finfo(probs_without_temp_scaling.dtype).min
)
......
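For return_logprob in this mode, the sampler computes logprobs with a single log_softmax over bf16, temperature-scaled logits rather than log(softmax(...)). A small sketch of the two paths with hypothetical logits, illustrating why one canonical formula is picked when the goal is exact agreement with the trainer:

import torch

torch.manual_seed(0)
logits = torch.randn(2, 32000, dtype=torch.bfloat16)
temperatures = torch.full((2, 1), 0.7)  # per-request temperatures, fp32

# Mirror of the on-policy branch: scale, then keep bf16.
scaled = logits.bfloat16().div(temperatures).bfloat16()

# Fused log_softmax: numerically stable, and the same formula a trainer
# evaluating token logprobs would typically use.
logprobs_fused = torch.log_softmax(scaled, dim=-1)

# Two-step alternative: softmax can underflow for very unlikely tokens, so its
# log may hit -inf and has to be clamped (the elif branch above).
logprobs_twostep = torch.log(torch.softmax(scaled, dim=-1))

# The two agree approximately, but not necessarily bitwise.
print((logprobs_fused - logprobs_twostep).abs().max())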
@@ -49,6 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
Qwen2Config = None
@@ -89,6 +90,9 @@ class Qwen2MLP(nn.Module):
self.act_fn = SiluAndMul()
def forward(self, x):
if get_global_server_args().rl_on_policy_target == "fsdp":
x = x.bfloat16()
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
@@ -275,6 +279,11 @@ class Qwen2Model(nn.Module):
quant_config=quant_config,
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
params_dtype=(
torch.float32
if get_global_server_args().rl_on_policy_target == "fsdp"
else None
),
)
else:
self.embed_tokens = PPMissingLayer()
@@ -295,7 +304,19 @@ class Qwen2Model(nn.Module):
prefix=add_prefix("layers", prefix),
)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
norm_kwargs = (
dict(
weight_dtype=torch.float32,
cast_x_before_out_mul=True,
override_orig_dtype=torch.float32,
fp32_residual=True,
)
if get_global_server_args().rl_on_policy_target == "fsdp"
else {}
)
self.norm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
)
else:
self.norm = PPMissingLayer(return_tuple=True)
......
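Within the Qwen2 model, the embedding is created in fp32 when targeting FSDP (params_dtype=torch.float32), while Qwen2MLP.forward casts its input back to bf16 so the projection GEMMs still run in half precision. A small sketch of that dtype hand-off, with a plain nn.Embedding and nn.Linear standing in for the parallel layers:

import torch
import torch.nn as nn

embed_tokens = nn.Embedding(1000, 64, dtype=torch.float32)          # fp32 master copy
gate_up_proj = nn.Linear(64, 2 * 128, bias=False, dtype=torch.bfloat16)

token_ids = torch.randint(0, 1000, (2, 8))
hidden = embed_tokens(token_ids)   # fp32 activations out of the embedding
hidden = hidden.bfloat16()         # cast before the bf16 GEMM, as in Qwen2MLP.forward
gate_up = gate_up_proj(hidden)
assert gate_up.dtype == torch.bfloat16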
@@ -29,6 +29,7 @@ from sglang.srt.model_loader.weight_utils import (
)
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
add_prefix,
get_cmo_stream,
@@ -88,8 +89,16 @@ class Qwen3Attention(nn.Module):
self.max_position_embeddings = max_position_embeddings
self.tp_rank = get_tensor_model_parallel_rank()
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
norm_kwargs = (
dict(
weight_dtype=torch.float32,
cast_x_before_out_mul=True,
)
if get_global_server_args().rl_on_policy_target == "fsdp"
else {}
)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -158,10 +167,18 @@
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if get_global_server_args().rl_on_policy_target == "fsdp":
hidden_states = hidden_states.bfloat16()
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
if get_global_server_args().rl_on_policy_target == "fsdp":
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
@@ -204,9 +221,22 @@ class Qwen3DecoderLayer(nn.Module):
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
norm_kwargs = (
dict(
weight_dtype=torch.float32,
cast_x_before_out_mul=True,
override_orig_dtype=torch.float32,
fp32_residual=True,
)
if get_global_server_args().rl_on_policy_target == "fsdp"
else {}
)
self.input_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
)
self.layer_scatter_modes = LayerScatterModes.init_new(
......
@@ -472,6 +472,7 @@ class ServerArgs:
enable_return_hidden_states: bool = False
scheduler_recv_interval: int = 1
numa_node: Optional[List[int]] = None
rl_on_policy_target: Optional[str] = None
enable_deterministic_inference: bool = False
# Dynamic batch tokenizer
@@ -1526,6 +1527,14 @@
)
def _handle_deterministic_inference(self):
if self.rl_on_policy_target is not None:
logger.warning(
"Enabling deterministic inference because rl_on_policy_target is set."
)
self.enable_deterministic_inference = True
# TODO remove this environment variable as a whole
os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
if self.enable_deterministic_inference:
# Check sampling backend
self.sampling_backend = "pytorch"
@@ -3300,6 +3309,13 @@
)
# For deterministic inference
parser.add_argument(
"--rl-on-policy-target",
type=str,
default=ServerArgs.rl_on_policy_target,
choices=["fsdp"],
help="The training system that SGLang needs to match for true on-policy.",
)
parser.add_argument(
"--enable-deterministic-inference",
action="store_true",
......
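On the usage side, setting the flag is enough: _handle_deterministic_inference then turns on deterministic inference and the pytorch sampling backend. A hypothetical configuration sketch (the model path is a placeholder):

from sglang.srt.server_args import ServerArgs

# Roughly equivalent CLI (hypothetical model path):
#   python -m sglang.launch_server --model-path Qwen/Qwen3-8B --rl-on-policy-target fsdp
args = ServerArgs(model_path="Qwen/Qwen3-8B", rl_on_policy_target="fsdp")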
@@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F
from utils import GeluAndMul, SiluAndMul, precision
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
@@ -17,6 +18,8 @@ class TestActivation(CustomTestCase):
dtype = [torch.float16, torch.bfloat16]
def _silu_and_mul_test(self, m, n, dtype):
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
x = torch.randn([m, n], dtype=dtype)
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
......
@@ -20,6 +20,7 @@ from utils import (
torch_w8a8_per_column_moe,
)
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
@@ -149,6 +150,8 @@ class TestSharedExpert(CustomTestCase):
self._int8_shared_expert(*params)
def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
dtype = torch.bfloat16
prepack = True
......
@@ -6,6 +6,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_utils import CustomTestCase
@@ -96,6 +97,8 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.f
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""This function performs fused moe with block-wise quantization using native torch."""
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
......
@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_utils import CustomTestCase
@@ -35,6 +36,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
"""This function performs fused moe with per-column int8 quantization using native torch."""
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
B, D = a.shape
# Perform per-token quantization
a_q, a_s = per_token_quant_int8(a)
......
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.srt.utils import is_hip
from sglang.test.test_utils import CustomTestCase
@@ -63,6 +64,8 @@ class TestFusedMOE(CustomTestCase):
a1_scale=None,
a2_scale=None,
):
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
......
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsQuantInfo
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_utils import CustomTestCase
@@ -56,6 +57,8 @@ class TestFusedMOE(CustomTestCase):
topk,
return_per_expert: bool = False,
):
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
......
@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_utils import CustomTestCase
@@ -40,6 +41,8 @@ def fp8_mask(a, mask):
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
"""This function performs fused moe with per-column int8 quantization using native torch."""
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
B, D = a.shape
# Perform per-token quantization
a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
......
@@ -6,6 +6,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
@@ -116,6 +117,8 @@ def quantize_weights(
def torch_moe(a, w1, w2, score, topk):
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
......