Unverified Commit e3e0bc50 authored by fzyzcjy, committed by GitHub

[Feature] SPMD for SGLang + Verl (#3852)

parent bac414ab
...@@ -149,6 +149,12 @@ jobs: ...@@ -149,6 +149,12 @@ jobs:
cd test/srt cd test/srt
python3 test_update_weights_from_distributed.py python3 test_update_weights_from_distributed.py
- name: Test VerlEngine
timeout-minutes: 10
run: |
cd test/srt
python3 test_verl_engine.py
- name: Test expert parallelism (EP=2) - name: Test expert parallelism (EP=2)
timeout-minutes: 10 timeout-minutes: 10
run: | run: |
......
import datetime
import os
import sys
from torch.distributed.device_mesh import init_device_mesh
from sglang.srt.entrypoints.verl_engine import VerlEngine
def run():
"""
Example command:
```
torchrun --nproc_per_node=8 offline_batch_inference_torchrun.py
```
"""
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
def _log(text):
t = datetime.datetime.now().strftime("%H:%M:%S")
print(f"[{t}] [rank={rank}] {text}")
_log(
f'start {local_rank=} {rank=} {world_size=} {sys.argv=} {os.environ.get("CUDA_VISIBLE_DEVICES")}'
)
tp_size = 4
dp_size = 2
assert world_size == tp_size * dp_size
device_mesh_kwargs = dict(
mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"]
)
device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs)
_log(f"{device_mesh_cpu=}")
tp_rank = device_mesh_cpu.get_local_rank("tp")
dp_rank = device_mesh_cpu.get_local_rank("dp")
_log(f"{tp_rank=} {tp_size=} ; {dp_rank=} {dp_size=}")
model_name, mem_fraction_static = "meta-llama/Llama-3.2-1B-Instruct", 0.1
# model_name, mem_fraction_static = "meta-llama/Llama-3.1-70B-Instruct", 0.9 # test large models
# model_name, mem_fraction_static = "deepseek-ai/DeepSeek-V2-Lite", 0.8
for k in ["TORCHELASTIC_USE_AGENT_STORE"]:
if k in os.environ:
del os.environ[k]
fragment = VerlEngine(
model_path=model_name,
mem_fraction_static=mem_fraction_static,
device_mesh_cpu=device_mesh_cpu["tp"],
base_gpu_id=dp_rank,
gpu_id_step=dp_size,
port=30000,
# for DeepSeek-V2-Lite + DP Attention
# enable_dp_attention=True, port=30000 + dp_rank * 100,
)
_log(f"{fragment=}")
prompt_all = [
["1+1=2, 1+2=3, 1+3=4, 1+4=", "9-1=8, 8-1=7, 7-1="],
["2*1=2, 2*2=4, 2*3=", "8/2=4, 6/2="],
]
prompt = prompt_all[dp_rank]
output = fragment.generate(
prompt=prompt,
sampling_params=dict(max_new_tokens=16, temperature=0.0),
)
_log(f"{prompt=} {output=}")
fragment.shutdown()
_log(f"End script")
if __name__ == "__main__":
run()
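
For reference, a quick sanity check (hypothetical, not part of the patch) of the GPU layout that the script above produces with `base_gpu_id=dp_rank` and `gpu_id_step=dp_size`: each DP replica gets an interleaved set of GPUs, so DP group 0 runs on GPUs 0, 2, 4, 6 and DP group 1 on GPUs 1, 3, 5, 7.

```python
# Hypothetical helper, for illustration only: reproduces the launcher's formula
# gpu_id = base_gpu_id + tp_rank * gpu_id_step with base_gpu_id=dp_rank, gpu_id_step=dp_size.
def gpus_for_dp_rank(dp_rank: int, tp_size: int = 4, dp_size: int = 2) -> list:
    return [dp_rank + tp_rank * dp_size for tp_rank in range(tp_size)]

assert gpus_for_dp_rank(0) == [0, 2, 4, 6]  # DP replica 0
assert gpus_for_dp_rank(1) == [1, 3, 5, 7]  # DP replica 1
```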
...@@ -271,10 +271,18 @@ class Engine: ...@@ -271,10 +271,18 @@ class Engine:
self.tokenizer_manager.update_weights_from_distributed(obj, None) self.tokenizer_manager.update_weights_from_distributed(obj, None)
) )
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): def update_weights_from_tensor(
"""Update weights from distributed source.""" self,
named_tensors: List[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
flush_cache: bool = True,
):
"""Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
to avoid duplicated operations such as clearing cache."""
obj = UpdateWeightsFromTensorReqInput( obj = UpdateWeightsFromTensorReqInput(
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors) serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
load_format=load_format,
flush_cache=flush_cache,
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete( return loop.run_until_complete(
...@@ -384,7 +392,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic ...@@ -384,7 +392,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
) )
for tp_rank in tp_rank_range: for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False) reader, writer = mp.Pipe(duplex=False)
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node gpu_id = (
server_args.base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process( proc = mp.Process(
target=run_scheduler_process, target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer), args=(server_args, port_args, gpu_id, tp_rank, None, writer),
......
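
A minimal usage sketch (not part of the patch) of the new `flush_cache` flag: when pushing many tensors one by one, flush the radix cache only after the last update, which is how `VerlEngine` drives this API below. `engine` and `new_state_dict` are assumptions for illustration.

```python
# Hypothetical caller-side sketch; `engine` is an sglang Engine and
# `new_state_dict` is an ordinary name -> torch.Tensor mapping (both assumed).
def push_weights_one_by_one(engine, new_state_dict):
    named_tensors = list(new_state_dict.items())
    for i, (name, tensor) in enumerate(named_tensors):
        engine.update_weights_from_tensor(
            named_tensors=[(name, tensor)],
            load_format=None,                           # use the model's own load_weights path
            flush_cache=(i == len(named_tensors) - 1),  # flush the cache only once, at the end
        )
```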
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.server import Engine
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
class VerlEngine:
def __init__(
self,
device_mesh_cpu: DeviceMesh,
nnodes: int = 1,
**kwargs,
):
self._device_mesh_cpu = device_mesh_cpu
self._tp_rank = device_mesh_cpu.get_local_rank()
self._tp_size = device_mesh_cpu.size()
tp_size_per_node = self._tp_size // nnodes
node_rank = self._tp_rank // tp_size_per_node
first_rank_in_node = self._tp_rank % tp_size_per_node == 0
if first_rank_in_node:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
self._engine = Engine(
**kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
)
else:
self._engine = None
dist.barrier(group=self._device_mesh_cpu.get_group())
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Union[List[Dict], Dict]] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
) -> Dict:
"""
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
if self._tp_rank == 0:
output = self._engine.generate(
prompt=prompt,
sampling_params=sampling_params,
input_ids=input_ids,
image_data=image_data,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
custom_logit_processor=custom_logit_processor,
)
else:
output = None
        # Simplest possible implementation; if this becomes too slow, extract the tensors and broadcast them via gloo instead.
[output] = broadcast_pyobj(
data=[output],
rank=self._tp_rank,
dist_group=self._device_mesh_cpu.get_group(),
src=self._device_mesh_cpu.mesh[0].item(),
)
return output
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
):
        # Simplest possible implementation; there is plenty of room for optimization if this becomes a bottleneck.
for tensor_index, (name, tensor) in enumerate(named_tensors):
serialized_tensor = MultiprocessingSerializer.serialize(
_preprocess_tensor_for_update_weights(tensor)
)
if self._tp_rank == 0:
gathered_serialized_tensors = [None for _ in range(self._tp_size)]
else:
gathered_serialized_tensors = None
dist.gather_object(
obj=serialized_tensor,
object_gather_list=gathered_serialized_tensors,
dst=self._device_mesh_cpu.mesh.tolist()[0],
group=self._device_mesh_cpu.get_group(),
)
if self._tp_rank == 0:
self._engine.update_weights_from_tensor(
named_tensors=[
(
name,
LocalSerializedTensor(values=gathered_serialized_tensors),
)
],
load_format=load_format,
flush_cache=tensor_index == len(named_tensors) - 1,
)
def release_memory_occupation(self):
if self._tp_rank == 0:
self._engine.release_memory_occupation()
def resume_memory_occupation(self):
if self._tp_rank == 0:
self._engine.resume_memory_occupation()
def shutdown(self):
if self._engine is not None:
self._engine.shutdown()
def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
if isinstance(tensor, DTensor):
return tensor.full_tensor()
return tensor
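
As a usage note, a hedged sketch of how a trainer could hand its (possibly DTensor-sharded) FSDP weights to `VerlEngine`; `verl_engine` and `fsdp_model` are assumed to exist on every rank of the TP group, and DTensors are materialized via `full_tensor()` by `_preprocess_tensor_for_update_weights` above.

```python
# Hypothetical sketch, assuming `verl_engine` (a VerlEngine) and `fsdp_model`
# (an FSDP-wrapped HF model) already exist on every rank of the TP group.
def push_trainer_weights(verl_engine, fsdp_model):
    state_dict = fsdp_model.state_dict()  # values may be DTensors / sharded tensors
    verl_engine.update_weights_from_tensor(named_tensors=list(state_dict.items()))
```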
...@@ -121,7 +121,7 @@ class DataParallelController: ...@@ -121,7 +121,7 @@ class DataParallelController:
args=(server_args, tmp_port_args, base_gpu_id, dp_rank), args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
) )
threads.append(thread) threads.append(thread)
base_gpu_id += server_args.tp_size base_gpu_id += server_args.tp_size * server_args.gpu_id_step
# Free all sockets before starting the threads to launch TP workers # Free all sockets before starting the threads to launch TP workers
for sock in sockets: for sock in sockets:
...@@ -177,7 +177,11 @@ class DataParallelController: ...@@ -177,7 +177,11 @@ class DataParallelController:
rank_port_args.nccl_port = port_args.nccl_port rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False) reader, writer = mp.Pipe(duplex=False)
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process( proc = mp.Process(
target=run_scheduler_process, target=run_scheduler_process,
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer), args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
......
...@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput: ...@@ -449,6 +449,8 @@ class UpdateWeightsFromDistributedReqOutput:
@dataclass @dataclass
class UpdateWeightsFromTensorReqInput: class UpdateWeightsFromTensorReqInput:
serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor] serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor]
load_format: Optional[str]
flush_cache: bool
@dataclass @dataclass
......
...@@ -1760,8 +1760,9 @@ class Scheduler: ...@@ -1760,8 +1760,9 @@ class Scheduler:
success, message = self.tp_worker.update_weights_from_tensor(recv_req) success, message = self.tp_worker.update_weights_from_tensor(recv_req)
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success: if success:
flash_cache_success = self.flush_cache() if recv_req.flush_cache:
assert flash_cache_success, "Cache flush failed after updating weights" flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
else: else:
logger.error(message) logger.error(message)
return UpdateWeightsFromTensorReqOutput(success, message) return UpdateWeightsFromTensorReqOutput(success, message)
......
...@@ -205,7 +205,10 @@ class TpModelWorker: ...@@ -205,7 +205,10 @@ class TpModelWorker:
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
success, message = self.model_runner.update_weights_from_tensor( success, message = self.model_runner.update_weights_from_tensor(
MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors) named_tensors=MultiprocessingSerializer.deserialize(
recv_req.serialized_named_tensors
),
load_format=recv_req.load_format,
) )
return success, message return success, message
......
...@@ -17,7 +17,8 @@ import gc ...@@ -17,7 +17,8 @@ import gc
import json import json
import logging import logging
import time import time
from typing import List, Optional, Tuple from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import ( ...@@ -56,10 +57,12 @@ from sglang.srt.mem_cache.memory_pool import (
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import ( from sglang.srt.utils import (
MultiprocessingSerializer,
enable_show_time_cost, enable_show_time_cost,
get_available_gpu_memory, get_available_gpu_memory,
init_custom_process_group, init_custom_process_group,
...@@ -514,8 +517,21 @@ class ModelRunner: ...@@ -514,8 +517,21 @@ class ModelRunner:
logger.error(error_msg) logger.error(error_msg)
return False, error_msg return False, error_msg
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): def update_weights_from_tensor(
self.model.load_weights(named_tensors) self,
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
load_format: Optional[str] = None,
):
named_tensors = [
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
for name, tensor in named_tensors
]
if load_format == "direct":
_model_load_weights_direct(self.model, named_tensors)
elif load_format is None:
self.model.load_weights(named_tensors)
else:
raise NotImplementedError(f"Unknown load_format={load_format}")
return True, "Success" return True, "Success"
def get_weights_by_name( def get_weights_by_name(
...@@ -836,3 +852,26 @@ class ModelRunner: ...@@ -836,3 +852,26 @@ class ModelRunner:
if rope_scaling is None: if rope_scaling is None:
return False return False
return rope_scaling.get("type", None) == "mrope" return rope_scaling.get("type", None) == "mrope"
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
for name, tensor in named_tensors:
default_weight_loader(params_dict[name], tensor)
def _unwrap_tensor(tensor, tp_rank):
if isinstance(tensor, LocalSerializedTensor):
return tensor.get(tp_rank)
return tensor
@dataclass
class LocalSerializedTensor:
"""torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
The i-th element in the list corresponds to i-th rank's GPU."""
values: List[bytes]
def get(self, rank: int):
return MultiprocessingSerializer.deserialize(self.values[rank])
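
To make the data flow concrete, a small hypothetical sketch of the `LocalSerializedTensor` round trip: every TP rank serializes a handle to its local shard, rank 0 gathers the handles, and each scheduler process later recovers only its own shard with `.get(tp_rank)` (all variable names below are assumptions).

```python
# Hypothetical round-trip sketch; `local_tensor`, `gathered_handles`, and `tp_rank`
# are assumed to come from the surrounding gather logic (see VerlEngine above).
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from sglang.srt.utils import MultiprocessingSerializer

def pack_on_each_rank(local_tensor):
    # every TP rank: serialize a handle to its local GPU tensor (no data copy)
    return MultiprocessingSerializer.serialize(local_tensor)

def unpack_on_scheduler(gathered_handles, tp_rank):
    # rank 0 gathers the handles and wraps them; each ModelRunner picks its own shard
    return LocalSerializedTensor(values=gathered_handles).get(tp_rank)
```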
...@@ -336,12 +336,6 @@ class GemmaForCausalLM(nn.Module): ...@@ -336,12 +336,6 @@ class GemmaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}"
)
EntryClass = GemmaForCausalLM EntryClass = GemmaForCausalLM
...@@ -437,12 +437,5 @@ class Gemma2ForCausalLM(nn.Module): ...@@ -437,12 +437,5 @@ class Gemma2ForCausalLM(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}"
)
EntryClass = Gemma2ForCausalLM EntryClass = Gemma2ForCausalLM
...@@ -82,6 +82,7 @@ class ServerArgs: ...@@ -82,6 +82,7 @@ class ServerArgs:
dist_timeout: Optional[int] = None # timeout for torch.distributed dist_timeout: Optional[int] = None # timeout for torch.distributed
download_dir: Optional[str] = None download_dir: Optional[str] = None
base_gpu_id: int = 0 base_gpu_id: int = 0
gpu_id_step: int = 1
# Logging # Logging
log_level: str = "info" log_level: str = "info"
...@@ -552,6 +553,12 @@ class ServerArgs: ...@@ -552,6 +553,12 @@ class ServerArgs:
default=ServerArgs.base_gpu_id, default=ServerArgs.base_gpu_id,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.", help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
) )
parser.add_argument(
"--gpu-id-step",
type=int,
default=ServerArgs.gpu_id_step,
help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
)
# Logging # Logging
parser.add_argument( parser.add_argument(
...@@ -957,6 +964,7 @@ class ServerArgs: ...@@ -957,6 +964,7 @@ class ServerArgs:
and (self.lora_paths is None or self.disable_radix_cache) and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress" ), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative" assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
if isinstance(self.lora_paths, list): if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths lora_paths = self.lora_paths
......
...@@ -1386,7 +1386,6 @@ def get_ip() -> str: ...@@ -1386,7 +1386,6 @@ def get_ip() -> str:
def get_open_port() -> int: def get_open_port() -> int:
port = os.getenv("SGLANG_PORT") port = os.getenv("SGLANG_PORT")
if port is not None: if port is not None:
while True: while True:
......
...@@ -21,9 +21,9 @@ import torch ...@@ -21,9 +21,9 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER from sglang.srt.server import Engine
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
DEFAULT_PROMPTS = [ DEFAULT_PROMPTS = [
"Apple is red. Banana is Yellow. " * 800 + "Apple is", "Apple is red. Banana is Yellow. " * 800 + "Apple is",
...@@ -95,9 +95,11 @@ class HFRunner: ...@@ -95,9 +95,11 @@ class HFRunner:
torch_dtype: torch.dtype, torch_dtype: torch.dtype,
model_type: str = "generation", model_type: str = "generation",
output_str_only: bool = False, output_str_only: bool = False,
trust_remote_code: bool = False,
): ):
self.model_type = model_type self.model_type = model_type
self.output_str_only = output_str_only self.output_str_only = output_str_only
self.trust_remote_code = trust_remote_code
self.in_queue = mp.Queue() self.in_queue = mp.Queue()
self.out_queue = mp.Queue() self.out_queue = mp.Queue()
...@@ -130,7 +132,7 @@ class HFRunner: ...@@ -130,7 +132,7 @@ class HFRunner:
self.base_model = AutoModelForCausalLM.from_pretrained( self.base_model = AutoModelForCausalLM.from_pretrained(
model_path, model_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
trust_remote_code=False, trust_remote_code=self.trust_remote_code,
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
).cuda() ).cuda()
elif self.model_type == "embedding": elif self.model_type == "embedding":
...@@ -147,7 +149,11 @@ class HFRunner: ...@@ -147,7 +149,11 @@ class HFRunner:
).cuda() ).cuda()
else: else:
raise Exception(f"Unrecognized model type {self.model_type}") raise Exception(f"Unrecognized model type {self.model_type}")
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype) self.tokenizer = get_tokenizer(
model_path,
torch_dtype=torch.dtype,
trust_remote_code=self.trust_remote_code,
)
# Run forward # Run forward
while True: while True:
...@@ -157,74 +163,15 @@ class HFRunner: ...@@ -157,74 +163,15 @@ class HFRunner:
if prompts is not None: if prompts is not None:
if self.model_type == "generation": if self.model_type == "generation":
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = self.tokenizer.encode(
p, return_tensors="pt"
).cuda()
else:
input_ids = torch.tensor([p], device="cuda")
if lora_paths is not None and lora_paths[i] is not None:
from peft import PeftModel
self.model = PeftModel.from_pretrained(
self.base_model,
lora_paths[i],
torch_dtype=torch_dtype,
is_trainable=False,
)
else:
self.model = self.base_model
outputs = self.model.generate(
input_ids,
do_sample=False,
temperature=None,
top_p=None,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=(not self.output_str_only),
)
text = self.tokenizer.decode(
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
)
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
if not self.output_str_only:
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
[
get_top_logprobs(
logits[0], NUM_TOP_LOGPROBS
).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = self.model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(
input_logits, NUM_TOP_LOGPROBS
).tolist()
)
del input_logits
out_queue.put( out_queue.put(
ModelOutput( self.forward_generation_raw(
output_strs=output_strs, prompts=prompts,
top_input_logprobs=top_input_logprobs, max_new_tokens=max_new_tokens,
top_output_logprobs=top_output_logprobs, base_model=self.base_model,
tokenizer=self.tokenizer,
lora_paths=lora_paths,
torch_dtype=torch_dtype,
output_str_only=self.output_str_only,
) )
) )
...@@ -269,6 +216,79 @@ class HFRunner: ...@@ -269,6 +216,79 @@ class HFRunner:
self.model_proc.terminate() self.model_proc.terminate()
self.in_queue = self.out_queue = None self.in_queue = self.out_queue = None
@staticmethod
def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
base_model,
tokenizer,
lora_paths,
torch_dtype: torch.dtype,
output_str_only: bool,
) -> ModelOutput:
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
else:
input_ids = torch.tensor([p], device="cuda")
if lora_paths is not None and lora_paths[i] is not None:
from peft import PeftModel
model = PeftModel.from_pretrained(
base_model,
lora_paths[i],
torch_dtype=torch_dtype,
is_trainable=False,
)
else:
model = base_model
outputs = model.generate(
input_ids,
do_sample=False,
temperature=None,
top_p=None,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=(not output_str_only),
)
text = tokenizer.decode(
outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
)
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
if not output_str_only:
# outputs.scores: (num_token, 1, vocab_size)
top_output_logprobs.append(
[
get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
)
del input_logits
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
)
class SRTRunner: class SRTRunner:
def __init__( def __init__(
...@@ -284,6 +304,7 @@ class SRTRunner: ...@@ -284,6 +304,7 @@ class SRTRunner:
disable_cuda_graph: bool = False, disable_cuda_graph: bool = False,
disable_radix_cache: bool = False, disable_radix_cache: bool = False,
mem_fraction_static: float = 0.65, mem_fraction_static: float = 0.65,
trust_remote_code: bool = False,
): ):
self.model_type = model_type self.model_type = model_type
self.is_generation = model_type == "generation" self.is_generation = model_type == "generation"
...@@ -293,7 +314,7 @@ class SRTRunner: ...@@ -293,7 +314,7 @@ class SRTRunner:
dtype=get_dtype_str(torch_dtype), dtype=get_dtype_str(torch_dtype),
port=port, port=port,
mem_fraction_static=mem_fraction_static, mem_fraction_static=mem_fraction_static,
trust_remote_code=False, trust_remote_code=trust_remote_code,
is_embedding=not self.is_generation, is_embedding=not self.is_generation,
lora_paths=lora_paths, lora_paths=lora_paths,
max_loras_per_batch=max_loras_per_batch, max_loras_per_batch=max_loras_per_batch,
...@@ -301,7 +322,7 @@ class SRTRunner: ...@@ -301,7 +322,7 @@ class SRTRunner:
disable_cuda_graph=disable_cuda_graph, disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache, disable_radix_cache=disable_radix_cache,
) )
self.tokenizer = get_tokenizer(model_path) self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
def forward( def forward(
self, self,
...@@ -310,54 +331,11 @@ class SRTRunner: ...@@ -310,54 +331,11 @@ class SRTRunner:
lora_paths=None, lora_paths=None,
): ):
if self.is_generation: if self.is_generation:
# the return value contains logprobs from prefill return self.forward_generation_raw(
output_strs = [] prompts=prompts,
top_input_logprobs = [] max_new_tokens=max_new_tokens,
top_output_logprobs = [] lora_paths=lora_paths,
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} engine=self.engine,
for i, prompt in enumerate(prompts):
response = self.engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
text = response["text"]
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
)
top_output_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
) )
else: else:
response = self.engine.encode(prompts) response = self.engine.encode(prompts)
...@@ -379,18 +357,11 @@ class SRTRunner: ...@@ -379,18 +357,11 @@ class SRTRunner:
only return output strings and no logprobs only return output strings and no logprobs
""" """
if self.is_generation: if self.is_generation:
# the return value contains logprobs from prefill return self.batch_forward_generation_raw(
output_strs = [] prompts=prompts,
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} max_new_tokens=max_new_tokens,
response = self.engine.generate( lora_paths=lora_paths,
prompts, engine=self.engine,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
) )
else: else:
response = self.engine.encode(prompts) response = self.engine.encode(prompts)
...@@ -408,6 +379,84 @@ class SRTRunner: ...@@ -408,6 +379,84 @@ class SRTRunner:
self.engine.shutdown() self.engine.shutdown()
del self.engine del self.engine
@staticmethod
def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
lora_paths,
engine,
):
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for i, prompt in enumerate(prompts):
response = engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
text = response["text"]
# Check if the text is empty or only whitespace.
if not text.strip():
raise ValueError(
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
)
top_output_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
)
@staticmethod
def batch_forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
lora_paths,
engine,
):
# the return value contains logprobs from prefill
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
response = engine.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
)
def monkey_patch_gemma2_sdpa(): def monkey_patch_gemma2_sdpa():
""" """
...@@ -422,3 +471,52 @@ def monkey_patch_gemma2_sdpa(): ...@@ -422,3 +471,52 @@ def monkey_patch_gemma2_sdpa():
return config return config
setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa) setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
def check_close_model_outputs(
hf_outputs: ModelOutput,
srt_outputs: ModelOutput,
prefill_tolerance: float,
decode_tolerance: float,
rouge_l_tolerance: float,
debug_text: str = "",
check_logprobs: bool = True,
):
# Compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
print(f"{rouge_l_scores=}")
assert all(
score >= rouge_l_tolerance for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
if check_logprobs:
for i in range(len(hf_outputs.output_strs)):
# Compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
input_len = hf_logprobs.shape[0]
print(
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with {debug_text} "
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# Compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
print(
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with {debug_text} "
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
...@@ -536,7 +536,7 @@ def test_hellaswag_select(): ...@@ -536,7 +536,7 @@ def test_hellaswag_select():
# Compute accuracy # Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels)) accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
print(f"{accuracy=}, {accuracy_gen=}") print(f"{accuracy=}, {accuracy_gen=}")
assert np.abs(accuracy_gen - accuracy) < 0.05 assert np.abs(accuracy_gen - accuracy) < 0.1
assert np.abs(latency_gen - latency) < 1 assert np.abs(latency_gen - latency) < 1
return accuracy, latency return accuracy, latency
......
...@@ -74,7 +74,7 @@ class TestSRTBackend(unittest.TestCase): ...@@ -74,7 +74,7 @@ class TestSRTBackend(unittest.TestCase):
# Run twice to capture more bugs # Run twice to capture more bugs
for _ in range(2): for _ in range(2):
accuracy, latency = test_hellaswag_select() accuracy, latency = test_hellaswag_select()
self.assertGreater(accuracy, 0.69) self.assertGreater(accuracy, 0.65)
def test_gen_min_new_tokens(self): def test_gen_min_new_tokens(self):
test_gen_min_new_tokens() test_gen_min_new_tokens()
......
...@@ -27,8 +27,13 @@ from typing import List ...@@ -27,8 +27,13 @@ from typing import List
import torch import torch
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner from sglang.test.runners import (
from sglang.test.test_utils import calculate_rouge_l, is_in_ci DEFAULT_PROMPTS,
HFRunner,
SRTRunner,
check_close_model_outputs,
)
from sglang.test.test_utils import is_in_ci
@dataclasses.dataclass @dataclasses.dataclass
...@@ -39,6 +44,7 @@ class ModelCase: ...@@ -39,6 +44,7 @@ class ModelCase:
decode_tolerance: float = 5e-2 decode_tolerance: float = 5e-2
rouge_l_tolerance: float = 1 rouge_l_tolerance: float = 1
skip_long_prompt: bool = False skip_long_prompt: bool = False
trust_remote_code: bool = False
# Popular models that run on the CI # Popular models that run on the CI
...@@ -53,7 +59,9 @@ ALL_OTHER_MODELS = [ ...@@ -53,7 +59,9 @@ ALL_OTHER_MODELS = [
ModelCase("Qwen/Qwen2.5-14B-Instruct"), ModelCase("Qwen/Qwen2.5-14B-Instruct"),
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True), ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True), ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
ModelCase("THUDM/glm-4-9b-chat"), ModelCase(
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
),
ModelCase("openai-community/gpt2"), ModelCase("openai-community/gpt2"),
ModelCase("microsoft/Phi-3-small-8k-instruct"), ModelCase("microsoft/Phi-3-small-8k-instruct"),
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
...@@ -87,6 +95,7 @@ class TestGenerationModels(unittest.TestCase): ...@@ -87,6 +95,7 @@ class TestGenerationModels(unittest.TestCase):
model_path, model_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
model_type="generation", model_type="generation",
trust_remote_code=model_case.trust_remote_code,
) as hf_runner: ) as hf_runner:
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens) hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
...@@ -95,48 +104,18 @@ class TestGenerationModels(unittest.TestCase): ...@@ -95,48 +104,18 @@ class TestGenerationModels(unittest.TestCase):
tp_size=model_case.tp_size, tp_size=model_case.tp_size,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
model_type="generation", model_type="generation",
trust_remote_code=model_case.trust_remote_code,
) as srt_runner: ) as srt_runner:
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens) srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
for i in range(len(prompts)): check_close_model_outputs(
# Compare input logprobs hf_outputs=hf_outputs,
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) srt_outputs=srt_outputs,
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) prefill_tolerance=model_case.prefill_tolerance,
input_len = hf_logprobs.shape[0] decode_tolerance=model_case.decode_tolerance,
print( rouge_l_tolerance=model_case.rouge_l_tolerance,
"prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs)) debug_text=f"model_path={model_path} prompts={prompts}",
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# Compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
print(
"decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
)
if input_len <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# Compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
rouge_l_scores = calculate_rouge_l(
hf_outputs.output_strs, srt_outputs.output_strs
) )
print(f"{rouge_l_scores=}")
assert all(
score >= rouge_l_tolerance for score in rouge_l_scores
), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
def test_ci_models(self): def test_ci_models(self):
for model_case in CI_MODELS: for model_case in CI_MODELS:
......
...@@ -26,6 +26,34 @@ class TestUpdateWeightsFromTensor(unittest.TestCase): ...@@ -26,6 +26,34 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):
engine.shutdown() engine.shutdown()
def test_update_weights_from_tensor_load_format_direct(self):
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
write_param_names = [
f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
]
read_param_names = [
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
]
_check_param(
engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
)
new_tensor = torch.full((3072, 2048), 1.5)
engine.update_weights_from_tensor(
[
(write_param_name, new_tensor.clone())
for write_param_name in write_param_names
],
load_format="direct",
)
for read_param_name in read_param_names[:3]:
_check_param(engine, read_param_name, [1.5] * 5)
engine.shutdown()
def _check_param(engine, param_name, expect_values): def _check_param(engine, param_name, expect_values):
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5] actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
......
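
A note on the new test above: with `load_format="direct"`, tensors are written into parameters by their exact internal names (see `_model_load_weights_direct` in model_runner.py), which is why the test writes the fused `qkv_proj.weight` name and then reads values back through `k_proj.weight` via `get_weights_by_name`. A hypothetical sketch of that direct path:

```python
# Hypothetical sketch mirroring _model_load_weights_direct: no per-model weight
# remapping, just a name-keyed copy into the existing parameters.
import torch

def load_direct(model: torch.nn.Module, named_tensors):
    params = dict(model.named_parameters())
    for name, tensor in named_tensors:
        params[name].data.copy_(tensor)
```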
import multiprocessing
import multiprocessing as mp
import os
import random
import traceback
import unittest
from multiprocessing import Process
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
ShardedStateDictConfig,
ShardingStrategy,
StateDictType,
)
from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import is_port_available
from sglang.test.runners import (
HFRunner,
SRTRunner,
check_close_model_outputs,
get_dtype_str,
)
from sglang.test.test_utils import is_in_ci
_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
_TORCH_DTYPE = torch.float16
# Set to false to temporarily debug issues unrelated to weight update
_ENABLE_UPDATE_WEIGHTS = True
# _ENABLE_UPDATE_WEIGHTS = False
# TODO: should we add more models here, and should this list be kept in sync with test_generation_models.py?
CI_MODELS = [
dict(model_path="meta-llama/Llama-3.1-8B-Instruct"),
dict(model_path="google/gemma-2-2b"),
]
ALL_OTHER_MODELS = [
dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
dict(model_path="Qwen/Qwen2-1.5B"),
dict(
model_path="Qwen/Qwen2.5-14B-Instruct",
mem_fraction_static=0.4,
tp_size=8,
tight_memory=True,
decode_tolerance=1.3,
    ),  # the same config in test_generation_models.py (Qwen + tp=8) gives a decode error of 1.22
dict(model_path="HuggingFaceTB/SmolLM-135M-Instruct", tp_size=3),
dict(model_path="allenai/OLMo-1B-0724-hf"),
dict(
model_path="THUDM/glm-4-9b-chat",
mem_fraction_static=0.1,
tp_size=8,
tight_memory=True,
),
dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
dict(
model_path="ibm-granite/granite-3.0-2b-instruct",
prefill_tolerance=0.22,
decode_tolerance=0.22,
),
    # These models currently fail in test_generation_models.py; that needs to be fixed first
# dict(model_path="openai-community/gpt2"),
# dict(model_path="microsoft/Phi-3-small-8k-instruct"),
]
class TestVerlEngine(unittest.TestCase):
@classmethod
def setUpClass(cls):
multiprocessing.set_start_method("spawn")
def assert_fragment_e2e_execution(
self,
index: int,
model_path: str,
mem_fraction_static: float = 0.4,
tp_size: int = 2,
tight_memory: bool = False,
prefill_tolerance: float = 0.1,
decode_tolerance: float = 0.1,
):
master_port = find_available_port(23456)
print(f"assert_fragment_e2e_execution START {index=} {model_path=}")
processes = []
output_reader, output_writer = mp.Pipe(duplex=False)
for tp_rank in range(tp_size):
p = Process(
target=_run_subprocess,
kwargs=dict(
tp_rank=tp_rank,
tp_size=tp_size,
master_port=master_port,
output_writer=output_writer,
model_path=model_path,
mem_fraction_static=mem_fraction_static,
tight_memory=tight_memory,
prefill_tolerance=prefill_tolerance,
decode_tolerance=decode_tolerance,
),
)
p.start()
processes.append(p)
for _ in range(tp_size):
self.assertTrue(
output_reader.recv(),
f"Subprocess has error, please see logs above. ({index=} {model_path=})",
)
for p in processes:
p.join()
def test_ci_models(self):
for index, model_info in enumerate(CI_MODELS):
self.assert_fragment_e2e_execution(index=index, **model_info)
def test_others(self):
if is_in_ci():
return
for index, model_info in enumerate(ALL_OTHER_MODELS):
self.assert_fragment_e2e_execution(index=index, **model_info)
# def test_adhoc(self):
# self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")
def _run_subprocess(
tp_rank: int,
tp_size: int,
master_port: int,
output_writer,
model_path: str,
mem_fraction_static: float,
tight_memory: bool,
prefill_tolerance: float,
decode_tolerance: float,
):
try:
print(f"subprocess[{tp_rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
torch.distributed.init_process_group(rank=tp_rank, world_size=tp_size)
torch.cuda.set_device(tp_rank)
mesh_kwargs = dict(mesh_shape=(tp_size, 1), mesh_dim_names=["tp", "pp"])
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
print(
f"subprocess[{tp_rank=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
)
# hf model is used for comparison
hf_model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
).cuda()
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
hf_outputs = HFRunner.forward_generation_raw(
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
base_model=hf_model,
tokenizer=hf_tokenizer,
lora_paths=None,
torch_dtype=_TORCH_DTYPE,
output_str_only=False,
)
print(
f"subprocess[{tp_rank=}] call hf.forward {hf_outputs=}",
flush=True,
)
if _ENABLE_UPDATE_WEIGHTS:
if tight_memory:
hf_model.cpu()
torch.cuda.empty_cache()
# test update weights
print(f"subprocess[{tp_rank=}] get_fsdp_state_dict", flush=True)
fsdp_state_dict = _get_fsdp_state_dict(hf_model=hf_model, tp_size=tp_size)
engine = VerlEngine(
model_path=model_path,
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
mem_fraction_static=mem_fraction_static,
random_seed=42,
trust_remote_code=True,
dtype=get_dtype_str(_TORCH_DTYPE),
device_mesh_cpu=inference_device_mesh_cpu["tp"],
)
print(f"subprocess[{tp_rank=}] {engine=}", flush=True)
if _ENABLE_UPDATE_WEIGHTS:
print(f"subprocess[{tp_rank=}] call update_weights_from_tensor", flush=True)
engine.update_weights_from_tensor(
[(k, v) for k, v in fsdp_state_dict.items()]
)
for enable_batch in [False, True]:
if enable_batch:
fn = SRTRunner.batch_forward_generation_raw
else:
fn = SRTRunner.forward_generation_raw
srt_outputs = fn(
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
lora_paths=None,
engine=engine,
)
print(
f"subprocess[{tp_rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
flush=True,
)
check_close_model_outputs(
hf_outputs=hf_outputs,
srt_outputs=srt_outputs,
prefill_tolerance=prefill_tolerance,
decode_tolerance=decode_tolerance,
rouge_l_tolerance=1,
check_logprobs=not enable_batch,
debug_text=f"{enable_batch=} {tp_rank=}",
)
execution_ok = True
except Exception as e:
print(f"subprocess[{tp_rank=}] has error: {e}", flush=True)
traceback.print_exc()
execution_ok = False
output_writer.send(execution_ok)
output_writer.close()
engine.shutdown()
print(f"subprocess[{tp_rank=}] end", flush=True)
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
def _get_fsdp_state_dict(hf_model, tp_size: int):
device_mesh = init_device_mesh(
"cuda", mesh_shape=(tp_size,), mesh_dim_names=["fsdp"]
)
mixed_precision = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
fsdp_model = FSDP(
hf_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh,
)
print(f"{fsdp_model=}")
FSDP.set_state_dict_type(
fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
return fsdp_model.state_dict()
# TODO: this logic is extracted from PortArgs.init_new; confirm whether refactoring that existing code is acceptable.
def find_available_port(base_port: int):
port = base_port + random.randint(100, 1000)
while True:
if is_port_available(port):
return port
if port < 60000:
port += 42
else:
port -= 43
if __name__ == "__main__":
unittest.main()