"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "3d682e5ce3d56445484934135f660660cfbacfbd"
Unverified commit 8fbba3de authored by Ying Sheng, committed by GitHub

Fix bugs (fp8 checkpoints, triton cache manager) (#729)

parent ae0f6130
@@ -70,11 +70,6 @@ docker run --gpus all \
```
### Common Notes
- If you see errors from the Triton compiler, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html) by
```
pip uninstall -y triton triton-nightly
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```
- If you cannot install FlashInfer, check out its [installation](https://docs.flashinfer.ai/installation.html#) page. If you still cannot install it, you can use the slower Triton kernels by adding `--disable-flashinfer` when launching the server.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
@@ -157,6 +152,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
```
- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/custom_chat_template.md).
- To enable fp8 quantization, you can add `--quantization fp8` on an fp16 checkpoint, or directly load an fp8 checkpoint without specifying any arguments.
- To enable experimental torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes.
### Supported Models
...
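For context (not part of this commit), here is a minimal sketch of how the launch flags mentioned in the README notes above could be used. The model path is the README's own example, the two flags are independent options, and wrapping the CLI in `subprocess` is purely illustrative.

```
import subprocess

# Base launch command; the model path is the README's own example.
base = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",
]

# fp8 quantization of an fp16 checkpoint (fp8 checkpoints load without extra flags).
fp8_cmd = base + ["--quantization", "fp8"]

# Experimental torch.compile support (helps small models at small batch sizes).
compile_cmd = base + ["--enable-torch-compile"]

subprocess.run(compile_cmd, check=True)  # pick whichever variant applies
```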
@@ -30,9 +30,11 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_llama3_405b_fp8,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
    monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger("srt.model_runner")
@@ -118,6 +120,13 @@ class ModelRunner:
            seed=42,
            skip_tokenizer_init=True,
        )
        if is_llama3_405b_fp8(self.model_config):
            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
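            # The FP8 checkpoint's config reports 16 KV heads, but each head's K/V
            # rows are stored twice; after de-duplication in the patched loader
            # (see monkey_patch_vllm_qvk_linear_loader) the effective count is 8.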
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
            monkey_patch_vllm_qvk_linear_loader()
        self.dtype = vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
...
@@ -202,15 +202,12 @@ def launch_server(
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )
    if server_args.tp_size * server_args.dp_size > 1:
    if server_args.tp_size // server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    if server_args.chat_template:
        # TODO: replace this with huggingface transformers template
        load_chat_template_for_openai_api(server_args.chat_template)

    if server_args.enable_torch_compile:
        _set_torch_compile_config()
...
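For reference (not part of this commit): Triton reads the `TRITON_CACHE_MANAGER` environment variable as a `module:ClassName` string, which is why `maybe_set_triton_cache_manager()` above only has to point it at `sglang.srt.utils:CustomCacheManager` (set in the utils changes below). A minimal sketch of resolving such a spec follows; `resolve_cache_manager` is a hypothetical helper for illustration, not an sglang or Triton API.

```
import importlib
import os

def resolve_cache_manager(spec: str):
    """Resolve a 'package.module:ClassName' spec into the class it names."""
    module_name, _, class_name = spec.partition(":")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# Example with the value this patch sets (requires sglang to be installed).
os.environ.setdefault("TRITON_CACHE_MANAGER", "sglang.srt.utils:CustomCacheManager")
manager_cls = resolve_cache_manager(os.environ["TRITON_CACHE_MANAGER"])
print(manager_cls.__name__)  # CustomCacheManager
```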
@@ -21,6 +21,7 @@ import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
    FileCacheManager,
    default_cache_dir,
@@ -471,7 +472,7 @@ def maybe_set_triton_cache_manager() -> None:
    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    if cache_manager is None:
        manager = "sglang.srt.utils:CustomCacheManager"
        logger.info("Setting Triton cache manager to: %s", manager)
        logger.debug("Setting Triton cache manager to: %s", manager)
        os.environ["TRITON_CACHE_MANAGER"] = manager
@@ -615,3 +616,51 @@ def set_ulimit(target_soft_limit=65535):
        resource.setrlimit(resource_type, (target_soft_limit, current_hard))
    except ValueError as e:
        logger.warning(f"Failed to set RLIMIT_NOFILE: {e}")

def is_llama3_405b_fp8(model_config):
    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
    if (
        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
        and model_config.hf_config.hidden_size == 16384
        and model_config.hf_config.intermediate_size == 53248
        and model_config.hf_config.num_hidden_layers == 126
        and model_config.hf_config.num_key_value_heads == 16
        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
    ):
        return True
    return False

def monkey_patch_vllm_qvk_linear_loader():
    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
    from vllm.model_executor.layers.linear import QKVParallelLinear

    origin_weight_loader = QKVParallelLinear.weight_loader

    def get_original_weight(loaded_weight, head_dim):
        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
        dim = loaded_weight.shape[1]
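        # De-interleave: block 2*i holds the i-th real KV head, so copy it into
        # position i and keep only the first n_kv_head * head_dim rows.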
        for i in range(n_kv_head):
            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
                2 * i * head_dim : (2 * i + 1) * head_dim, :
            ]
        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
        return original_kv_weight

    def weight_loader_srt(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        if (
            loaded_shard_id in ["k", "v"]
            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
        ):
            loaded_weight = get_original_weight(loaded_weight, self.head_size)
        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
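As a standalone sanity check of the de-interleaving in `get_original_weight` (a sketch, not part of the patch): build a toy KV weight whose head blocks are duplicated back to back, which is the checkpoint layout the patch assumes, apply the same slicing, and confirm the original weight comes back.

```
import torch

def toy_deinterleave_check(n_kv_head: int = 4, head_dim: int = 3, dim: int = 5) -> None:
    # "Original" KV weight: n_kv_head blocks of head_dim rows each.
    original = torch.arange(n_kv_head * head_dim * dim, dtype=torch.float32)
    original = original.reshape(n_kv_head * head_dim, dim)

    # Checkpoint-style weight assumed by the patch: each head block duplicated,
    # so the leading dimension is 2 * n_kv_head * head_dim.
    duplicated = (
        original.reshape(n_kv_head, head_dim, dim)
        .repeat_interleave(2, dim=0)
        .reshape(2 * n_kv_head * head_dim, dim)
    )

    # Same slicing as get_original_weight above.
    recovered = duplicated.clone()
    for i in range(n_kv_head):
        recovered[i * head_dim : (i + 1) * head_dim, :] = recovered[
            2 * i * head_dim : (2 * i + 1) * head_dim, :
        ]
    recovered = recovered[: n_kv_head * head_dim, :]

    assert recovered.shape == (n_kv_head * head_dim, dim)
    assert torch.equal(recovered, original)

toy_deinterleave_check()
```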