Unverified Commit d17986f8 authored by narutolhy, committed by GitHub

Enable optional FP32 compute for LM Head (#10729)

Thanks to the MiniMax Team and Chenyang Zhao for their support.
parent 8831c55c
......@@ -113,6 +113,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--quantization` | The quantization method. | None |
| `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8. Otherwise, the KV cache scaling factors default to 1.0, which may cause accuracy issues. | None |
| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use the model data type. 'fp8_e5m2' and 'fp8_e4m3' are supported for CUDA 11.8+. | auto |
| `--enable-fp32-lm-head` | If set, the LM head outputs (logits) are in FP32. | False |
## Memory and scheduling
......
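For readers trying the new flag, here is a hedged usage sketch. It assumes sglang's offline `Engine` forwards keyword arguments to `ServerArgs`, so the new `enable_fp32_lm_head` field can be set directly; the model path is only a placeholder. The server-launch equivalent is passing `--enable-fp32-lm-head` to `python -m sglang.launch_server`.

```python
# Hedged usage sketch: enable FP32 LM-head compute via the offline engine.
# Assumes Engine(**kwargs) forwards kwargs to ServerArgs; model path is a placeholder.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    enable_fp32_lm_head=True,  # compute the LM head logits in FP32
)
outputs = llm.generate(
    ["The capital of France is"], {"temperature": 0, "max_new_tokens": 8}
)
print(outputs[0]["text"])
llm.shutdown()
```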
......@@ -220,6 +220,7 @@ class LogitsProcessor(nn.Module):
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
......@@ -461,7 +462,11 @@ class LogitsProcessor(nn.Module):
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
if hasattr(lm_head, "weight"):
if use_intel_amx_backend(lm_head):
if self.use_fp32_lm_head:
logits = torch.matmul(
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
)
elif use_intel_amx_backend(lm_head):
logits = torch.ops.sgl_kernel.weight_packed_linear(
hidden_states.to(lm_head.weight.dtype),
lm_head.weight,
......@@ -475,7 +480,15 @@ class LogitsProcessor(nn.Module):
else:
# GGUF models
# TODO: use weight_packed_linear for GGUF models
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
if self.use_fp32_lm_head:
with torch.cuda.amp.autocast(enabled=False):
logits = lm_head.quant_method.apply(
lm_head, hidden_states.to(torch.float32), embedding_bias
)
else:
logits = lm_head.quant_method.apply(
lm_head, hidden_states, embedding_bias
)
if self.logit_scale is not None:
logits.mul_(self.logit_scale)
......
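A minimal, self-contained sketch of what the new branch does, using illustrative shapes rather than anything from the commit: when the flag is on, both the hidden states and the LM head weight are upcast to float32 before the matmul, so the logits come out in FP32 instead of the model's half-precision dtype.

```python
# Minimal sketch of the FP32 LM-head branch above; shapes are illustrative only.
import torch

batch, hidden, vocab = 2, 64, 128
hidden_states = torch.randn(batch, hidden, dtype=torch.bfloat16)
lm_head_weight = torch.randn(vocab, hidden, dtype=torch.bfloat16)

# Default path: the logits inherit the model's compute dtype.
logits_default = torch.matmul(hidden_states, lm_head_weight.T)

# enable_fp32_lm_head path: upcast both operands and matmul in FP32.
logits_fp32 = torch.matmul(
    hidden_states.to(torch.float32), lm_head_weight.to(torch.float32).T
)

# Prints: torch.bfloat16 torch.float32 (bf16 matmul needs a recent PyTorch build)
print(logits_default.dtype, logits_fp32.dtype)
```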
......@@ -90,6 +90,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
......
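For context, adding `enable_fp32_lm_head` to `GLOBAL_SERVER_ARGS_KEYS` is what publishes the server flag to model code through `global_server_args_dict`, which the `LogitsProcessor` constructor reads above. A toy sketch of that plumbing pattern, with the real module layout simplified away:

```python
# Toy sketch of the flag-plumbing pattern; not the real sglang module layout.
global_server_args_dict = {"enable_fp32_lm_head": True}  # populated from ServerArgs at startup

class TinyLogitsProcessor:
    def __init__(self):
        # Mirrors the new line added to LogitsProcessor.__init__ above.
        self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]

print(TinyLogitsProcessor().use_fp32_lm_head)  # True
```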
......@@ -167,6 +167,7 @@ class ServerArgs:
quantization: Optional[str] = None
quantization_param_path: Optional[str] = None
kv_cache_dtype: str = "auto"
enable_fp32_lm_head: bool = False
# Memory and scheduling
mem_fraction_static: Optional[float] = None
......@@ -1392,6 +1393,11 @@ class ServerArgs:
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
        help='Data type for kv cache storage. "auto" will use the model data type. "fp8_e5m2" and "fp8_e4m3" are supported for CUDA 11.8+.',
)
parser.add_argument(
"--enable-fp32-lm-head",
action="store_true",
help="If set, the LM head outputs (logits) are in FP32.",
)
# Memory and scheduling
parser.add_argument(
......
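The flag itself follows the standard `argparse` store_true pattern: it defaults to False and flips to True only when passed on the command line. A small standalone sketch of that behavior:

```python
# Standalone sketch of the store_true flag added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-fp32-lm-head",
    action="store_true",
    help="If set, the LM head outputs (logits) are in FP32.",
)

print(parser.parse_args([]).enable_fp32_lm_head)                         # False
print(parser.parse_args(["--enable-fp32-lm-head"]).enable_fp32_lm_head)  # True
```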
import unittest
from types import SimpleNamespace
from unittest.mock import patch
import torch
import torch.nn as nn
import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.managers.schedule_batch import global_server_args_dict
class LMHeadStub(nn.Module):
def __init__(self, vocab, hidden, dtype, device="cuda"):
super().__init__()
self.weight = nn.Parameter(
torch.randn(vocab, hidden, dtype=dtype, device=device)
)
class DummyMeta:
gathered_buffer = None
next_token_logits_buffer = None
def compute_dp_attention_metadata(self): ...
class TestLMHeadFP32(unittest.TestCase):
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("needs CUDA GPU")
def _make_logprocessor(self, vocab_size, enable_fp32):
global_server_args_dict["enable_dp_lm_head"] = False
global_server_args_dict["enable_fp32_lm_head"] = enable_fp32
cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)
def _run_case(
self,
hidden_state_dtype,
enable_fp32,
weights_dtype,
expected_a_dtype,
expected_b_dtype,
):
device = "cuda"
BATCH_SIZE, HIDDEN_SIZE, VOCAB_SIZE = 2, 64, 128
hidden_state = torch.randn(
BATCH_SIZE, HIDDEN_SIZE, dtype=hidden_state_dtype, device=device
)
head = LMHeadStub(VOCAB_SIZE, HIDDEN_SIZE, dtype=weights_dtype, device=device)
meta = DummyMeta()
logprocessor = self._make_logprocessor(VOCAB_SIZE, enable_fp32)
original_matmul = torch.matmul
original_linear = F.linear
state = {
"called": False, # Whether a matmul/linear call has been intercepted yet
"operation": None, # Which operation was captured ("matmul" or "linear")
"a": None, # The dtype of the first input tensor to the operation
"b": None, # The dtype of the second input tensor to the operation
}
def probe_matmul(a, b, *args, **kw):
if not state["called"]:
state.update(called=True, operation="matmul", a=a.dtype, b=b.dtype)
return original_matmul(a, b, *args, **kw)
def probe_linear(x, w, bias=None):
if not state["called"]:
                state.update(called=True, operation="linear", a=x.dtype, b=w.dtype)
return original_linear(x, w, bias)
with patch("torch.matmul", new=probe_matmul), patch(
"torch.nn.functional.linear", new=probe_linear
):
logits = logprocessor._get_logits(hidden_state, head, meta)
self.assertEqual(hidden_state.dtype, hidden_state_dtype)
        self.assertTrue(state["called"], "LM head matmul/linear was not called")
self.assertEqual(state["a"], expected_a_dtype)
self.assertEqual(state["b"], expected_b_dtype)
def test_flag_true_fp16_activations(self):
self._run_case(torch.float16, True, torch.float16, torch.float32, torch.float32)
def test_flag_true_bf16_activations(self):
self._run_case(
torch.bfloat16, True, torch.bfloat16, torch.float32, torch.float32
)
def test_flag_false_fp16_path(self):
self._run_case(
torch.float16, False, torch.float16, torch.float16, torch.float16
)
def test_flag_false_bf16_path(self):
self._run_case(
torch.bfloat16, False, torch.bfloat16, torch.bfloat16, torch.bfloat16
)
if __name__ == "__main__":
unittest.main(verbosity=2)
......@@ -59,6 +59,7 @@ suites = {
TestFile("quant/test_int8_kernel.py", 8),
TestFile("quant/test_triton_scaled_mm.py", 8),
TestFile("quant/test_w8a8_quantization.py", 46),
TestFile("rl/test_fp32_lm_head.py", 30),
TestFile("rl/test_update_weights_from_disk.py", 114),
TestFile("rl/test_update_weights_from_tensor.py", 48),
TestFile("test_abort.py", 51),
......