"src/vscode:/vscode.git/clone" did not exist on "d7ffe601664e4bd94415b2a3fbe6106c4f48f9a0"
Unverified commit 98c73d71 authored by Lianmin Zheng, committed by GitHub

[Minor] make the `__init__` function of model_runner.py shorter (#4132)

parent fcc2e37f
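At a high level this is an extract-method refactor: the model-specific branching that previously sat inline in `ModelRunner.__init__` moves into a dedicated `model_specific_adjustment()` helper, and the cross-rank loading barrier moves into `load_model()`. A minimal sketch of the same pattern, using hypothetical `Runner`/`Args` names rather than the real sglang classes:

```python
# Minimal sketch of the extract-method pattern applied in this commit
# (hypothetical `Args`/`Runner` names, not the actual sglang classes).
from dataclasses import dataclass


@dataclass
class Args:
    enable_double_sparsity: bool = False
    attention_backend: str = "flashinfer"
    disable_cuda_graph: bool = False


class Runner:
    def __init__(self, args: Args):
        self.args = args
        # __init__ stays short: model-specific tweaks live in a helper.
        self.model_specific_adjustment()

    def model_specific_adjustment(self) -> None:
        args = self.args
        if args.enable_double_sparsity:
            # Same style of in-place server-arg rewriting as the real method.
            args.attention_backend = "triton"
            args.disable_cuda_graph = True


runner = Runner(Args(enable_double_sparsity=True))
assert runner.args.attention_backend == "triton"
```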
@@ -427,7 +427,7 @@ class CudaGraphRunner:
self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture()
def replay(self, forward_batch: ForwardBatch):
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size
......
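The new `skip_attn_backend_init` keyword defaults to `False`, so existing callers of `replay()` keep working while new callers can opt out of redundant backend setup. A toy illustration of why a defaulted keyword is backward compatible (a stand-in function, not the real `CudaGraphRunner`):

```python
# Toy stand-in for a method that grows an optional flag; not the real
# CudaGraphRunner.replay. Existing positional call sites keep working.
def replay(batch_size: int, skip_attn_backend_init: bool = False) -> str:
    if not skip_attn_backend_init:
        # Pretend to (re)initialize the attention backend before replay.
        pass
    return f"replayed batch of {batch_size}"


print(replay(8))                                # old call sites, unchanged
print(replay(8, skip_attn_backend_init=True))   # new opt-in fast path
```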
@@ -122,66 +122,17 @@ class ModelRunner:
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
# Model-specific adjustment
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
):
# TODO: add MLA optimization on CPU
if self.server_args.device != "cpu":
if server_args.enable_flashinfer_mla:
logger.info(
"MLA optimization is turned on. Use flashinfer mla backend."
)
self.server_args.attention_backend = "flashinfer_mla"
else:
logger.info("MLA optimization is turned on. Use triton backend.")
self.server_args.attention_backend = "triton"
if self.server_args.enable_double_sparsity:
logger.info(
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
)
self.server_args.attention_backend = "triton"
self.server_args.disable_cuda_graph = True
if self.server_args.ds_heavy_channel_type is None:
raise ValueError(
"Please specify the heavy channel type for double sparsity optimization."
)
self.init_double_sparsity_channel_config(
self.server_args.ds_heavy_channel_type
)
self.model_specific_adjustment()
if self.is_multimodal:
self.mem_fraction_static *= 0.95
logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"because this is a multimodal model."
)
if self.model_config.hf_config.architectures == [
"MllamaForConditionalGeneration"
]:
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
server_args.chunked_prefill_size = -1
if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration"
]:
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
# Global vars
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_outlines_disk_cache:
from outlines.caching import disable_cache
disable_cache()
# Global vars
global_server_args_dict.update(
{
"attention_backend": server_args.attention_backend,
@@ -203,6 +154,7 @@ class ModelRunner:
}
)
# CPU offload
set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
# Get memory before model loading
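The `--cpu-offload-gb` flag is a GB-denominated CLI value; the line above turns it into a byte budget using `1024**3` (i.e., it is treated as GiB). The arithmetic, spelled out as a hypothetical standalone snippet:

```python
# Standalone check of the GiB-to-bytes conversion above
# (hypothetical values; 1 GiB == 1024**3 bytes).
cpu_offload_gb = 4
cpu_offload_max_bytes = int(cpu_offload_gb * 1024**3)
assert cpu_offload_max_bytes == 4_294_967_296
```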
@@ -216,18 +168,6 @@ class ModelRunner:
self.sampler = Sampler()
self.load_model()
# Handle the case where some of models don't finish loading.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
# Apply torchao quantization
torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may have been applied
@@ -244,9 +184,11 @@ class ModelRunner:
else:
self.torch_tp_applied = False
# Init memory pool and attention backends
# Init lora
if server_args.lora_paths is not None:
self.init_lora_manager()
# Init memory pool and attention backends
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_running_requests,
@@ -260,10 +202,63 @@ class ModelRunner:
self.cuda_graph_runner = None
self.init_attention_backend()
def model_specific_adjustment(self):
server_args = self.server_args
if (
self.model_config.attention_arch == AttentionArch.MLA
and not server_args.disable_mla
):
# TODO: add MLA optimization on CPU
if server_args.device != "cpu":
if server_args.enable_flashinfer_mla:
logger.info(
"MLA optimization is turned on. Use flashinfer mla backend."
)
server_args.attention_backend = "flashinfer_mla"
else:
logger.info("MLA optimization is turned on. Use triton backend.")
server_args.attention_backend = "triton"
if server_args.enable_double_sparsity:
logger.info(
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
)
server_args.attention_backend = "triton"
server_args.disable_cuda_graph = True
if server_args.ds_heavy_channel_type is None:
raise ValueError(
"Please specify the heavy channel type for double sparsity optimization."
)
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
if self.is_multimodal:
self.mem_fraction_static *= 0.95
logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"because this is a multimodal model."
)
if self.model_config.hf_config.architectures == [
"MllamaForConditionalGeneration"
]:
logger.info("Automatically turn off --chunked-prefill-size for mllama.")
server_args.chunked_prefill_size = -1
if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration"
]:
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
torch.get_device_module(self.device).set_device(self.gpu_id)
if self.device == "cuda":
backend = "nccl"
elif self.device == "xpu":
@@ -400,6 +395,18 @@ class ModelRunner:
f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
)
# Handle the case where some ranks do not finish loading.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
def update_weights_from_disk(
self, model_path: str, load_format: str
) -> tuple[bool, str]:
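With this move, the barrier runs at the end of `load_model()`, so a rank that loaded successfully fails fast if its peers did not. `torch.distributed.monitored_barrier` is supported on gloo process groups only; a minimal single-process sketch of calling it with a timeout (the `world_size=1` setup exists just to make the example self-contained):

```python
# Minimal single-process sketch of monitored_barrier with a timeout.
# monitored_barrier is only supported on gloo process groups.
import datetime

import torch.distributed as dist

dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29501",
    rank=0,
    world_size=1,
)
try:
    dist.monitored_barrier(
        timeout=datetime.timedelta(seconds=30),
        wait_all_ranks=True,  # report every late rank, not just the first
    )
except RuntimeError:
    # The real code re-raises this with a hint about OOM or a slow node.
    raise
finally:
    dist.destroy_process_group()
```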
@@ -772,6 +779,10 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
@@ -880,18 +891,24 @@ class ModelRunner:
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
def forward(
self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
) -> LogitsProcessorOutput:
if (
forward_batch.forward_mode.is_cuda_graph()
and self.cuda_graph_runner
and self.cuda_graph_runner.can_run(forward_batch)
):
return self.cuda_graph_runner.replay(forward_batch)
return self.cuda_graph_runner.replay(
forward_batch, skip_attn_backend_init=skip_attn_backend_init
)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
return self.forward_extend(
forward_batch, skip_attn_backend_init=skip_attn_backend_init
)
elif forward_batch.forward_mode.is_idle():
return self.forward_idle(forward_batch)
else:
......
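The `forward()` hunk threads `skip_attn_backend_init` into the CUDA-graph replay and extend paths; the dispatch itself remains a simple mode check. A sketch of that dispatch shape, with hypothetical stand-ins for `ForwardBatch` and the runner (not the real sglang control flow verbatim):

```python
# Shape of the forward() dispatch above, using hypothetical stand-ins.
from dataclasses import dataclass
from enum import Enum, auto


class Mode(Enum):
    DECODE = auto()
    EXTEND = auto()
    IDLE = auto()


@dataclass
class Batch:
    mode: Mode
    can_use_cuda_graph: bool = False


def forward(batch: Batch, skip_attn_backend_init: bool = False) -> str:
    # Fast path: replay a captured CUDA graph when the batch qualifies.
    if batch.can_use_cuda_graph:
        return f"cuda-graph replay (skip_init={skip_attn_backend_init})"
    if batch.mode is Mode.DECODE:
        return "decode"
    if batch.mode is Mode.EXTEND:
        return f"extend (skip_init={skip_attn_backend_init})"
    if batch.mode is Mode.IDLE:
        return "idle"
    raise ValueError(f"unsupported mode: {batch.mode}")


print(forward(Batch(Mode.EXTEND)))
print(forward(Batch(Mode.DECODE, can_use_cuda_graph=True),
              skip_attn_backend_init=True))
```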