Unverified commit 2ce32db6 authored by Lianmin Zheng, committed by GitHub

Let reward model take text inputs instead of message lists (#1907)
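Before this change, reward-model requests carried a raw message list ({"conv": ...}); now the caller renders the conversation to plain text with the tokenizer's chat template and sends it through the same path as embedding inputs. A minimal sketch of the new call pattern, assuming the URM-LLaMa reward model used in the updated tests (the conversation content is illustrative):

    # Sketch of the new text-based input path; the conversation content is made up.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("LxzGordon/URM-LLaMa-3.1-8B")
    conv = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
    # The chat template is applied on the client side, so the server only ever sees text.
    prompt = tokenizer.apply_chat_template(conv, tokenize=False)
    # `prompt` is now a plain string and can be scored like any embedding input,
    # e.g. srt_runner.forward([prompt]) in the updated reward-model test.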


Co-authored-by: Kyle Corbitt <kyle@corbt.com>
parent 793b79db
@@ -88,11 +88,8 @@ CONTEXT_LENGTH_KEYS = [
 def get_context_length(config):
-    """Get the context length of a model from a huggingface model configs.
-    And here the config should be text_config part if the model is a multimodal
-    LLM.
-    """
-    text_config = getattr(config, "text_config", config)
+    """Get the context length of a model from a huggingface model configs."""
+    text_config = config
     rope_scaling = getattr(text_config, "rope_scaling", None)
     if rope_scaling:
         rope_scaling_factor = rope_scaling.get("factor", 1)
@@ -238,7 +238,7 @@ class EmbeddingReqInput:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
                 self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
+            self.sampling_params["max_new_tokens"] = 0
         else:
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]

@@ -248,7 +248,7 @@ class EmbeddingReqInput:
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
+                self.sampling_params[i]["max_new_tokens"] = 0

     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
@@ -34,6 +34,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_

@@ -303,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
     def forward(

@@ -311,11 +313,15 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
-        )
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)

     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
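The new get_embedding flag keeps the default generation path (LogitsProcessor) untouched and only routes to the pooler when an embedding or reward score is requested. As a rough sketch of what Pooler(pooling_type=PoolingType.LAST, normalize=True) is assumed to compute (the helper name and packed-batch layout here are illustrative, not sglang's actual implementation):

    import torch

    def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
        # hidden_states: [total_tokens, hidden_size] for a packed batch of sequences
        # seq_lens: [batch_size] number of tokens in each sequence
        last_indices = torch.cumsum(seq_lens, dim=0) - 1  # index of each sequence's final token
        pooled = hidden_states[last_indices]              # [batch_size, hidden_size]
        return torch.nn.functional.normalize(pooled, p=2, dim=-1)  # L2-normalize each embedding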
@@ -36,9 +36,7 @@ class LlamaEmbeddingModel(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.pooler(hidden_states, forward_batch)

-    def load_weights(
-        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),

@@ -49,7 +47,7 @@ class LlamaEmbeddingModel(nn.Module):
         ]
         params_dict = dict(self.model.named_parameters())

-        def load_weights_per_param(name, loaded_weight):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:

@@ -78,12 +76,6 @@ class LlamaEmbeddingModel(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

-        if name is None or loaded_weight is None:
-            for name, loaded_weight in weights:
-                load_weights_per_param(name, loaded_weight)
-        else:
-            load_weights_per_param(name, loaded_weight)
-

 class MistralModel(LlamaEmbeddingModel):
     pass
@@ -52,7 +52,12 @@ class LlamaForSequenceClassification(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForSequenceClassification is only used for embedding"
+
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         scores = self.score(hidden_states)
@@ -618,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
         for i, image in enumerate(forward_batch.image_inputs):
-            if image == None:
+            if image is None:
                 continue
             start_idx = extend_start_loc_cpu[i]
             prefix_len = prefix_lens_cpu[i]
@@ -254,7 +254,7 @@ app.put("/encode")(encode_request)
 async def judge_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request."""
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret

@@ -696,24 +696,8 @@ class Runtime:
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
     ):
-        if isinstance(prompt, str) or isinstance(prompt[0], str):
-            # embedding
-            json_data = {
-                "text": prompt,
-            }
-            response = requests.post(
-                self.url + "/encode",
-                json=json_data,
-            )
-        else:
-            # reward
-            json_data = {
-                "conv": prompt,
-            }
-            response = requests.post(
-                self.url + "/judge",
-                json=json_data,
-            )
+        json_data = {"text": prompt}
+        response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

     def __del__(self):
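With the branching removed, Runtime.encode always sends an embedding-style payload. A hedged example of the equivalent raw HTTP call (the host/port and prompt text are assumptions; the payload shape follows the simplified code above):

    import json
    import requests

    # Reward and embedding requests now share the same {"text": ...} payload;
    # the old {"conv": ...} message-list form is no longer sent by Runtime.encode.
    json_data = {"text": "chat-template-rendered conversation goes here"}
    response = requests.post("http://localhost:30000/encode", json=json_data)
    print(json.dumps(response.json()))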
@@ -273,6 +273,7 @@ class SRTRunner:
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
         )
+        self.tokenizer = get_tokenizer(model_path)

     def forward(
         self,

@@ -366,7 +367,7 @@ class SRTRunner:
             return ModelOutput(embed_logits=logits)
         else:
             scores = [x["embedding"][0] for x in response]
-            return ModelOutput(scores=logits)
+            return ModelOutput(scores=scores)

     def __enter__(self):
         return self
@@ -30,6 +30,10 @@ TORCH_DTYPES = [torch.float16]

 class TestEmbeddingModels(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn", force=True)
+
     def assert_close_prefill_logits(
         self,
         prompts,

@@ -74,9 +78,4 @@ class TestEmbeddingModels(unittest.TestCase):

 if __name__ == "__main__":
-    try:
-        mp.set_start_method("spawn")
-    except RuntimeError:
-        pass
-
     unittest.main()
@@ -63,9 +63,10 @@ TORCH_DTYPES = [torch.float16]

 class TestGenerationModels(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        mp.set_start_method("spawn")
+        mp.set_start_method("spawn", force=True)

     def assert_close_logits_and_output_strs(
         self,
@@ -18,10 +18,10 @@ import unittest
 import torch

-from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
+from sglang.test.runners import HFRunner, SRTRunner

 MODELS = [
-    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 2e-2),
+    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
 ]
 TORCH_DTYPES = [torch.float16]

@@ -43,6 +43,10 @@ CONVS = [

 class TestRewardModels(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn", force=True)
+
     def assert_close_reward_scores(
         self,
         convs,

@@ -63,12 +67,13 @@ class TestRewardModels(unittest.TestCase):
                 torch_dtype=torch_dtype,
                 model_type="reward",
             ) as srt_runner:
-                srt_outputs = srt_runner.forward(convs)
+                prompts = srt_runner.tokenizer.apply_chat_template(convs, tokenize=False)
+                srt_outputs = srt_runner.forward(prompts)

             hf_scores = torch.tensor(hf_outputs.scores)
             srt_scores = torch.tensor(srt_outputs.scores)
-            print(hf_scores)
-            print(srt_scores)
+            print(f"{hf_scores=}")
+            print(f"{srt_scores=}")

             assert torch.all(
                 abs(hf_scores - srt_scores) < tolerance

@@ -83,9 +88,4 @@ class TestRewardModels(unittest.TestCase):

 if __name__ == "__main__":
-    try:
-        mp.set_start_method("spawn")
-    except RuntimeError:
-        pass
-
     unittest.main()
@@ -8,7 +8,7 @@ suites = {
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
         "models/test_lora.py",
-        # "models/test_reward_models.py",
+        "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",