Unverified Commit e3e0bc50 authored by fzyzcjy, committed by GitHub

[Feature] SPMD for SGLang + Verl (#3852)

parent bac414ab
......@@ -149,6 +149,12 @@ jobs:
cd test/srt
python3 test_update_weights_from_distributed.py
- name: Test VerlEngine
timeout-minutes: 10
run: |
cd test/srt
python3 test_verl_engine.py
- name: Test expert parallelism (EP=2)
timeout-minutes: 10
run: |
......
import datetime
import os
import sys
from torch.distributed.device_mesh import init_device_mesh
from sglang.srt.entrypoints.verl_engine import VerlEngine
def run():
"""
Example command:
```
torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
```
"""
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
def _log(text):
t = datetime.datetime.now().strftime("%H:%M:%S")
print(f"[{t}] [rank={rank}] {text}")
_log(
f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
)
tp_size = 4
dp_size = 2
assert world_size == tp_size * dp_size
device_mesh_kwargs = dict(
mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
)
device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
_log(f"{device_mesh_cpu=}")
tp_rank = device_mesh_cpu.get_local_rank("tp")
dp_rank = device_mesh_cpu.get_local_rank("dp")
_log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")
model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
# model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9 # test large models
# model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8
for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
if k in os.environ:
del os.environ[k]
fragment = VerlEngine(
model_path=model_name,
mem_fraction_static=mem_fraction_static,
device_mesh_cpu=device_mesh_cpu["tp"],
base_gpu_id=dp_rank,
gpu_id_step=dp_size,
port=30000,
# for DeepSeek-V2-Lite + DP Attention
# enable_dp_attention=True, port=30000 + dp_rank * 100,
)
_log(f"{fragment=}")
prompt_all = [
["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
]
prompt = prompt_all[dp_rank]
output = fragment.generate(
prompt=prompt,
sampling_params=dict(max_new_tokens=16, temperature=0.0),
)
_log(f"{prompt=} {output=}")
fragment.shutdown()
_log(f"End script")
if __name__ == "__main__":
run()
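For orientation (not part of the diff), a minimal sketch of the GPU assignment produced by the arguments above, assuming the `gpu_id = base_gpu_id + (tp_rank % tp_size_per_node) * gpu_id_step` formula introduced later in this commit:
```
# Illustrative only: reproduce the rank -> GPU mapping of the torchrun example above
# (single node, tp_size=4, dp_size=2, base_gpu_id=dp_rank, gpu_id_step=dp_size).
tp_size, dp_size = 4, 2
tp_size_per_node = tp_size  # one node in this example
for dp_rank in range(dp_size):
    base_gpu_id, gpu_id_step = dp_rank, dp_size
    gpu_ids = [
        base_gpu_id + (tp_rank % tp_size_per_node) * gpu_id_step
        for tp_rank in range(tp_size)
    ]
    print(f"{dp_rank=} uses GPUs {gpu_ids}")
# Expected:
#   dp_rank=0 uses GPUs [0, 2, 4, 6]
#   dp_rank=1 uses GPUs [1, 3, 5, 7]
```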
......@@ -271,10 +271,18 @@ class Engine:
self.tokenizer_manager.update_weights_from_distributed(obj, None)
)
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
"""Update weights from distributed source."""
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
flush_cache: bool = True,
):
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
to avoid duplicated operations such as clearing cache."""
obj = UpdateWeightsFromTensorReqInput(
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
load_format=load_format,
flush_cache=flush_cache,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
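A hedged usage sketch of the new `flush_cache` flag (the `llm` and `weight_chunks` names are illustrative, not from this diff): when weights are pushed in several calls, only the last call needs to flush the KV cache.
```
# Sketch: update weights in several chunks and flush the cache only once, on the last call.
# `llm` is assumed to be an sglang Engine; `weight_chunks` a list of List[Tuple[str, torch.Tensor]].
for i, chunk in enumerate(weight_chunks):
    llm.update_weights_from_tensor(
        named_tensors=chunk,
        flush_cache=(i == len(weight_chunks) - 1),
    )
```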
......@@ -384,7 +392,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
gpu_id = (
server_args.base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.server import Engine
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
class VerlEngine:
def __init__(
self,
device_mesh_cpu: DeviceMesh,
nnodes: int = 1,
**kwargs,
):
self._device_mesh_cpu = device_mesh_cpu
self._tp_rank = device_mesh_cpu.get_local_rank()
self._tp_size = device_mesh_cpu.size()
tp_size_per_node = self._tp_size // nnodes
node_rank = self._tp_rank // tp_size_per_node
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
if first_rank_in_node:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = Engine(
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
)
else:
self._engine = None
dist.barrier(group=self._device_mesh_cpu.get_group())
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
) -> Dict:
"""
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
if self._tp_rank == 0:
output = self._engine.generate(
prompt=prompt,
sampling_params=sampling_params,
input_ids=input_ids,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
custom_logit_processor=custom_logit_processor,
)
else:
output = None
# Naive implementation; if this becomes too slow, extract the tensors and broadcast them via gloo instead
[output] = broadcast_pyobj(
data=[output],
rank=self._tp_rank,
dist_group=self._device_mesh_cpu.get_group(),
src=self._device_mesh_cpu.mesh[0].item(),
)
return output
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
):
# Naive implementation; there is plenty of room to optimize if this becomes a bottleneck
for tensor_index, (name, tensor) in enumerate(named_tensors):
serialized_tensor = MultiprocessingSerializer.serialize(
_preprocess_tensor_for_update_weights(tensor)
)
if self._tp_rank == 0:
gathered_serialized_tensors = [None for _ in range(self._tp_size)]
else:
gathered_serialized_tensors = None
dist.gather_object(
obj=serialized_tensor,
object_gather_list=gathered_serialized_tensors,
dst=self._device_mesh_cpu.mesh.tolist()[0],
group=self._device_mesh_cpu.get_group(),
)
if self._tp_rank == 0:
self._engine.update_weights_from_tensor(
named_tensors=[
(
name,
LocalSerializedTensor(values=gathered_serialized_tensors),
)
],
load_format=load_format,
flush_cache=tensor_index == len(named_tensors) - 1,
)
def release_memory_occupation(self):
if self._tp_rank == 0:
self._engine.release_memory_occupation()
def resume_memory_occupation(self):
if self._tp_rank == 0:
self._engine.resume_memory_occupation()
def shutdown(self):
if self._engine is not None:
self._engine.shutdown()
def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
if isinstance(tensor, DTensor):
return tensor.full_tensor()
return tensor
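To show how the pieces above fit together, a hedged end-to-end sketch of SPMD usage (every TP rank runs the same code; `hf_model` is loaded here only to provide full weights, and the mesh layout mirrors the test added further below):
```
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.verl_engine import VerlEngine

dist.init_process_group()  # one process per TP rank, e.g. launched via torchrun
tp_size = dist.get_world_size()
mesh_cpu = init_device_mesh("cpu", mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"])

# Collective construction: only tp_rank 0 hosts the inner Engine, the other ranks stay passive.
engine = VerlEngine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    device_mesh_cpu=mesh_cpu["tp"],
    mem_fraction_static=0.4,
)

# Push fresh weights from a trainer-side model; DTensors are turned into full tensors
# and gathered to rank 0 internally before being sent to the schedulers.
hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
engine.update_weights_from_tensor(list(hf_model.state_dict().items()))

# generate() broadcasts the result, so every rank returns the same output.
output = engine.generate(prompt="1+1=", sampling_params=dict(max_new_tokens=8, temperature=0.0))
engine.shutdown()
```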
......@@ -121,7 +121,7 @@ class DataParallelController:
args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
)
threads.append(thread)
base_gpu_id += server_args.tp_size
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
# Free all sockets before starting the threads to launch TP workers
for sock in sockets:
......@@ -177,7 +177,11 @@ class DataParallelController:
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
......
......@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
@dataclass
class UpdateWeightsFromTensorReqInput:
serialized_named_tensors: bytes  # serialized List[Tuple[str, torch.Tensor]]
load_format: Optional[str]
flush_cache: bool
@dataclass
......
......@@ -1760,8 +1760,9 @@ class Scheduler:
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO: extract the common code between update_weights_from_distributed and update_weights_from_tensor later
if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
if recv_req.flush_cache:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return UpdateWeightsFromTensorReqOutput(success, message)
......
......@@ -205,7 +205,10 @@ class TpModelWorker:
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
success, message = self.model_runner.update_weights_from_tensor(
MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors
),
load_format=recv_req.load_format,
)
return success, message
......
......@@ -17,7 +17,8 @@ import gc
import json
import logging
import time
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
......@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import (
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
MultiprocessingSerializer,
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
......@@ -514,8 +517,21 @@ class ModelRunner:
logger.error(error_msg)
return False, error_msg
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
self.model.load_weights(named_tensors)
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
load_format: Optional[str] = None,
):
named_tensors = [
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
for name, tensor in named_tensors
]
if load_format == "direct":
_model_load_weights_direct(self.model, named_tensors)
elif load_format is None:
self.model.load_weights(named_tensors)
else:
raise NotImplementedError(f"Unknown load_format={load_format}")
return True, "Success"
def get_weights_by_name(
......@@ -836,3 +852,26 @@ class ModelRunner:
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
for name, tensor in named_tensors:
default_weight_loader(params_dict[name], tensor)
def _unwrap_tensor(tensor, tp_rank):
if isinstance(tensor, LocalSerializedTensor):
return tensor.get(tp_rank)
return tensor
@dataclass
class LocalSerializedTensor:
"""torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
The i-th element in the list corresponds to i-th rank's GPU."""
values: List[bytes]
def get(self, rank: int):
return MultiprocessingSerializer.deserialize(self.values[rank])
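A hedged sketch contrasting the two `load_format` paths above (parameter names and shapes assume the Llama-3.2-1B layout used elsewhere in this commit; this mirrors the new test further below):
```
import torch
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")

# Default path (load_format=None): goes through the model's load_weights(), so
# checkpoint-style names are accepted and fused by each parameter's weight_loader.
engine.update_weights_from_tensor(
    [("model.layers.0.self_attn.k_proj.weight", torch.zeros(512, 2048))]
)

# "direct" path: default_weight_loader copies the tensor straight into the parameter of
# the same name, so the name must match a runtime parameter such as the fused qkv_proj.
engine.update_weights_from_tensor(
    [("model.layers.0.self_attn.qkv_proj.weight", torch.zeros(3072, 2048))],
    load_format="direct",
)

engine.shutdown()
```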
......@@ -336,12 +336,6 @@ class GemmaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}"
)
EntryClass = GemmaForCausalLM
......@@ -437,12 +437,5 @@ class Gemma2ForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}"
)
EntryClass = Gemma2ForCausalLM
......@@ -82,6 +82,7 @@ class ServerArgs:
dist_timeout: Optional[int] = None # timeout for torch.distributed
download_dir: Optional[str] = None
base_gpu_id: int = 0
gpu_id_step: int = 1
# Logging
log_level: str = "info"
......@@ -552,6 +553,12 @@ class ServerArgs:
default=ServerArgs.base_gpu_id,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)
parser.add_argument(
"--gpu-id-step",
type=int,
default=ServerArgs.gpu_id_step,
help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
)
# Logging
parser.add_argument(
......@@ -957,6 +964,7 @@ class ServerArgs:
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
......
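An illustrative use of the new flag together with `--base-gpu-id` (example commands in the style of the torchrun snippet above, not part of the diff): two servers can share one 8-GPU node by interleaving GPU IDs.
```
# Server A on GPUs 0, 2, 4, 6
python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct \
    --tp 4 --base-gpu-id 0 --gpu-id-step 2 --port 30000

# Server B on GPUs 1, 3, 5, 7
python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct \
    --tp 4 --base-gpu-id 1 --gpu-id-step 2 --port 30001
```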
......@@ -1386,7 +1386,6 @@ def get_ip() -> str:
def get_open_port() -> int:
port = os.getenv("SGLANG_PORT")
if port is not None:
while True:
......
......@@ -21,9 +21,9 @@ import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
from sglang.srt.server import Engine
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
DEFAULT_PROMPTS = [
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
......@@ -95,9 +95,11 @@ class HFRunner:
torch_dtype: torch.dtype,
model_type: str = "generation",
output_str_only: bool = False,
trust_remote_code: bool = False,
):
self.model_type = model_type
self.output_str_only = output_str_only
self.trust_remote_code = trust_remote_code
self.in_queue = mp.Queue()
self.out_queue = mp.Queue()
......@@ -130,7 +132,7 @@ class HFRunner:
self.base_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=False,
trust_remote_code=self.trust_remote_code,
low_cpu_mem_usage=True,
).cuda()
elif self.model_type == "embedding":
......@@ -147,7 +149,11 @@ class HFRunner:
).cuda()
else:
raise Exception(f"Unrecognized model type {self.model_type}")
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
self.tokenizer = get_tokenizer(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=self.trust_remote_code,
)
# Run forward
while True:
......@@ -157,74 +163,15 @@ class HFRunner:
if prompts is not None:
if self.model_type == "generation":
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = self.tokenizer.encode(
p, return_tensors="pt"
).cuda()
else:
input_ids = torch.tensor([p], device="cuda")
if lora_paths is not None and lora_paths[i] is not None:
from peft import PeftModel
self.model = PeftModel.from_pretrained(
self.base_model,
lora_paths[i],
torch_dtype=torch_dtype,
is_trainable=False,
)
else:
self.model = self.base_model
outputs = self.model.generate(
input_ids,
do_sample=False,
temperature=None,
top_p=None,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=(not self.output_str_only),
)
text = self.tokenizer.decode(
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
)
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
if not self.output_str_only:
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
[
get_top_logprobs(
logits[0], NUM_TOP_LOGPROBS
).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = self.model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(
input_logits, NUM_TOP_LOGPROBS
).tolist()
)
del input_logits
out_queue.put(
ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
self.forward_generation_raw(
prompts=prompts,
max_new_tokens=max_new_tokens,
base_model=self.base_model,
tokenizer=self.tokenizer,
lora_paths=lora_paths,
torch_dtype=torch_dtype,
output_str_only=self.output_str_only,
)
)
......@@ -269,6 +216,79 @@ class HFRunner:
self.model_proc.terminate()
self.in_queue = self.out_queue = None
@staticmethod
def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
base_model,
tokenizer,
lora_paths,
torch_dtype: torch.dtype,
output_str_only: bool,
) -> ModelOutput:
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
else:
input_ids = torch.tensor([p], device="cuda")
if lora_paths is not None and lora_paths[i] is not None:
from peft import PeftModel
model = PeftModel.from_pretrained(
base_model,
lora_paths[i],
torch_dtype=torch_dtype,
is_trainable=False,
)
else:
model = base_model
outputs = model.generate(
input_ids,
do_sample=False,
temperature=None,
top_p=None,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=(not output_str_only),
)
text = tokenizer.decode(
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
)
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
if not output_str_only:
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
[
get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
)
del input_logits
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
)
class SRTRunner:
def __init__(
......@@ -284,6 +304,7 @@ class SRTRunner:
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
mem_fraction_static: float = 0.65,
trust_remote_code: bool = False,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
......@@ -293,7 +314,7 @@ class SRTRunner:
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=mem_fraction_static,
trust_remote_code=False,
trust_remote_code=trust_remote_code,
is_embedding=not self.is_generation,
lora_paths=lora_paths,
max_loras_per_batch=max_loras_per_batch,
......@@ -301,7 +322,7 @@ class SRTRunner:
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
)
self.tokenizer = get_tokenizer(model_path)
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
def forward(
self,
......@@ -310,54 +331,11 @@ class SRTRunner:
lora_paths=None,
):
if self.is_generation:
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for i, prompt in enumerate(prompts):
response = self.engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
text = response["text"]
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
)
top_output_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
return self.forward_generation_raw(
prompts=prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
engine=self.engine,
)
else:
response = self.engine.encode(prompts)
......@@ -379,18 +357,11 @@ class SRTRunner:
only return output strings and no logprobs
"""
if self.is_generation:
# the return value contains logprobs from prefill
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
response = self.engine.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
return self.batch_forward_generation_raw(
prompts=prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
engine=self.engine,
)
else:
response = self.engine.encode(prompts)
......@@ -408,6 +379,84 @@ class SRTRunner:
self.engine.shutdown()
del self.engine
@staticmethod
def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
lora_paths,
engine,
):
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for i, prompt in enumerate(prompts):
response = engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
text = response["text"]
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
)
top_output_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
)
@staticmethod
def batch_forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
lora_paths,
engine,
):
# the return value contains logprobs from prefill
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
response = engine.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
)
def monkey_patch_gemma2_sdpa():
"""
......@@ -422,3 +471,52 @@ def monkey_patch_gemma2_sdpa():
return config
setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
def check_close_model_outputs(
hf_outputs: ModelOutput,
srt_outputs: ModelOutput,
prefill_tolerance: float,
decode_tolerance: float,
rouge_l_tolerance: float,
debug_text: str = "",
check_logprobs: bool = True,
):
# Compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
print(f"{rouge_l_scores=}")
assert all(
score >= rouge_l_tolerance for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
if check_logprobs:
for i in range(len(hf_outputs.output_strs)):
# Compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
input_len = hf_logprobs.shape[0]
print(
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with {debug_text} "
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# Compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
print(
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with {debug_text} "
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
......@@ -536,7 +536,7 @@ def test_hellaswag_select():
# Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
print(f"{accuracy=}, {accuracy_gen=}")
assert np.abs(accuracy_gen - accuracy) < 0.05
assert np.abs(accuracy_gen - accuracy) < 0.1
assert np.abs(latency_gen - latency) < 1
return accuracy, latency
......
......@@ -74,7 +74,7 @@ class TestSRTBackend(unittest.TestCase):
# Run twice to capture more bugs
for _ in range(2):
accuracy, latency = test_hellaswag_select()
self.assertGreater(accuracy, 0.69)
self.assertGreater(accuracy, 0.65)
def test_gen_min_new_tokens(self):
test_gen_min_new_tokens()
......
......@@ -27,8 +27,13 @@ from typing import List
import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l, is_in_ci
from sglang.test.runners import (
DEFAULT_PROMPTS,
HFRunner,
SRTRunner,
check_close_model_outputs,
)
from sglang.test.test_utils import is_in_ci
@dataclasses.dataclass
......@@ -39,6 +44,7 @@ class ModelCase:
decode_tolerance: float = 5e-2
rouge_l_tolerance: float = 1
skip_long_prompt: bool = False
trust_remote_code: bool = False
# Popular models that run on the CI
......@@ -53,7 +59,9 @@ ALL_OTHER_MODELS = [
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
ModelCase("THUDM/glm-4-9b-chat"),
ModelCase(
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
),
ModelCase("openai-community/gpt2"),
ModelCase("microsoft/Phi-3-small-8k-instruct"),
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
......@@ -87,6 +95,7 @@ class TestGenerationModels(unittest.TestCase):
model_path,
torch_dtype=torch_dtype,
model_type="generation",
trust_remote_code=model_case.trust_remote_code,
) as hf_runner:
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
......@@ -95,48 +104,18 @@ class TestGenerationModels(unittest.TestCase):
tp_size=model_case.tp_size,
torch_dtype=torch_dtype,
model_type="generation",
trust_remote_code=model_case.trust_remote_code,
) as srt_runner:
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
for i in range(len(prompts)):
# Compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
input_len = hf_logprobs.shape[0]
print(
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# Compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
print(
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# Compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
rouge_l_scores = calculate_rouge_l(
hf_outputs.output_strs, srt_outputs.output_strs
check_close_model_outputs(
hf_outputs=hf_outputs,
srt_outputs=srt_outputs,
prefill_tolerance=model_case.prefill_tolerance,
decode_tolerance=model_case.decode_tolerance,
rouge_l_tolerance=model_case.rouge_l_tolerance,
debug_text=f"model_path={model_path} prompts={prompts}",
)
print(f"{rouge_l_scores=}")
assert all(
score >= rouge_l_tolerance for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
def test_ci_models(self):
for model_case in CI_MODELS:
......
......@@ -26,6 +26,34 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):
engine.shutdown()
def test_update_weights_from_tensor_load_format_direct(self):
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
write_param_names = [
f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
]
read_param_names = [
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
]
_check_param(
engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
)
new_tensor = torch.full((3072, 2048), 1.5)
engine.update_weights_from_tensor(
[
(write_param_name, new_tensor.clone())
for write_param_name in write_param_names
],
load_format="direct",
)
for read_param_name in read_param_names[:3]:
_check_param(engine, read_param_name, [1.5] * 5)
engine.shutdown()
def _check_param(engine, param_name, expect_values):
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
......
import multiprocessing
import multiprocessing as mp
import os
import random
import traceback
import unittest
from multiprocessing import Process
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
ShardedStateDictConfig,
ShardingStrategy,
StateDictType,
)
from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import is_port_available
from sglang.test.runners import (
HFRunner,
SRTRunner,
check_close_model_outputs,
get_dtype_str,
)
from sglang.test.test_utils import is_in_ci
_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
_TORCH_DTYPE = torch.float16
# Set to False to temporarily debug issues unrelated to weight updates
_ENABLE_UPDATE_WEIGHTS = True
# _ENABLE_UPDATE_WEIGHTS = False
# TODO: maybe we should add more models here; should this list be kept in sync with test_generation_models.py?
CI_MODELS = [
dict(model_path="meta-llama/Llama-3.1-8B-Instruct"),
dict(model_path="google/gemma-2-2b"),
]
ALL_OTHER_MODELS = [
dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
dict(model_path="Qwen/Qwen2-1.5B"),
dict(
model_path="Qwen/Qwen2.5-14B-Instruct",
mem_fraction_static=0.4,
tp_size=8,
tight_memory=True,
decode_tolerance=1.3,
), # the same config (Qwen + tp=8) in test_generation_models.py gives a decode error of 1.22
dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3),
dict(model_path="allenai/OLMo-1B-0724-hf"),
dict(
model_path="THUDM/glm-4-9b-chat",
mem_fraction_static=0.1,
tp_size=8,
tight_memory=True,
),
dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
dict(
model_path="ibm-granite/granite-3.0-2b-instruct",
prefill_tolerance=0.22,
decode_tolerance=0.22,
),
# These models fail to run in test_generation_models.py; that needs to be fixed first
# dict(model_path="openai-community/gpt2"),
# dict(model_path="microsoft/Phi-3-small-8k-instruct"),
]
class TestVerlEngine(unittest.TestCase):
@classmethod
def setUpClass(cls):
multiprocessing.set_start_method("spawn")
def assert_fragment_e2e_execution(
self,
index: int,
model_path: str,
mem_fraction_static: float = 0.4,
tp_size: int = 2,
tight_memory: bool = False,
prefill_tolerance: float = 0.1,
decode_tolerance: float = 0.1,
):
master_port = find_available_port(23456)
print(f"assert_fragment_e2e_execution START {index=} {model_path=}")
processes = []
output_reader, output_writer = mp.Pipe(duplex=False)
for tp_rank in range(tp_size):
p = Process(
target=_run_subprocess,
kwargs=dict(
tp_rank=tp_rank,
tp_size=tp_size,
master_port=master_port,
output_writer=output_writer,
model_path=model_path,
mem_fraction_static=mem_fraction_static,
tight_memory=tight_memory,
prefill_tolerance=prefill_tolerance,
decode_tolerance=decode_tolerance,
),
)
p.start()
processes.append(p)
for _ in range(tp_size):
self.assertTrue(
output_reader.recv(),
f"Subprocess has error, please see logs above. ({index=} {model_path=})",
)
for p in processes:
p.join()
def test_ci_models(self):
for index, model_info in enumerate(CI_MODELS):
self.assert_fragment_e2e_execution(index=index, **model_info)
def test_others(self):
if is_in_ci():
return
for index, model_info in enumerate(ALL_OTHER_MODELS):
self.assert_fragment_e2e_execution(index=index, **model_info)
# def test_adhoc(self):
# self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")
def _run_subprocess(
tp_rank: int,
tp_size: int,
master_port: int,
output_writer,
model_path: str,
mem_fraction_static: float,
tight_memory: bool,
prefill_tolerance: float,
decode_tolerance: float,
):
try:
print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size)
torch.cuda.set_device(tp_rank)
mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"])
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
print(
f"subprocess[{tp_rank=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
)
# hf model is used for comparison
hf_model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
).cuda()
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
hf_outputs = HFRunner.forward_generation_raw(
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
base_model=hf_model,
tokenizer=hf_tokenizer,
lora_paths=None,
torch_dtype=_TORCH_DTYPE,
output_str_only=False,
)
print(
f"subprocess[{tp_rank=}] call hf.forward {hf_outputs=}",
flush=True,
)
if _ENABLE_UPDATE_WEIGHTS:
if tight_memory:
hf_model.cpu()
torch.cuda.empty_cache()
# test update weights
print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True)
fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size)
engine = VerlEngine(
model_path=model_path,
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
mem_fraction_static=mem_fraction_static,
random_seed=42,
trust_remote_code=True,
dtype=get_dtype_str(_TORCH_DTYPE),
device_mesh_cpu=inference_device_mesh_cpu["tp"],
)
print(f"subprocess[{tp_rank=}] {engine=}", flush=True)
if _ENABLE_UPDATE_WEIGHTS:
print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True)
engine.update_weights_from_tensor(
[(k, v) for k, v in fsdp_state_dict.items()]
)
for enable_batch in [False, True]:
if enable_batch:
fn = SRTRunner.batch_forward_generation_raw
else:
fn = SRTRunner.forward_generation_raw
srt_outputs = fn(
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
lora_paths=None,
engine=engine,
)
print(
f"subprocess[{tp_rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
flush=True,
)
check_close_model_outputs(
hf_outputs=hf_outputs,
srt_outputs=srt_outputs,
prefill_tolerance=prefill_tolerance,
decode_tolerance=decode_tolerance,
rouge_l_tolerance=1,
check_logprobs=not enable_batch,
debug_text=f"{enable_batch=} {tp_rank=}",
)
execution_ok = True
except Exception as e:
print(f"subprocess[{tp_rank=}] has error: {e}", flush=True)
traceback.print_exc()
execution_ok = False
output_writer.send(execution_ok)
output_writer.close()
engine.shutdown()
print(f"subprocess[{tp_rank=}] end", flush=True)
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
def _get_fsdp_state_dict(hf_model, tp_size: int):
device_mesh = init_device_mesh(
"cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"]
)
mixed_precision = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
fsdp_model = FSDP(
hf_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh,
)
print(f"{fsdp_model=}")
FSDP.set_state_dict_type(
fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
return fsdp_model.state_dict()
# TODO: this is extracted from PortArgs.init_new; confirm whether it is acceptable to refactor that existing code
def find_available_port(base_port: int):
port = base_port + random.randint(100, 1000)
while True:
if is_port_available(port):
return port
if port < 60000:
port += 42
else:
port -= 43
if __name__ == "__main__":
unittest.main()