"vscode:/vscode.git/clone" did not exist on "50962c433fd2bc55eedb667c99ea5f6717f7debc"
Unverified Commit 444a0244 authored by Ying Sheng, committed by GitHub
Browse files

Update vllm version to support llama3.1 (#705)

parent fa7ccb33
......@@ -21,7 +21,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"]
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
......
......@@ -73,6 +73,8 @@ def get_context_length(config):
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling:
rope_scaling_factor = config.rope_scaling["factor"]
if config.rope_scaling["rope_type"] == "llama3":
rope_scaling_factor = 1
else:
rope_scaling_factor = 1
......
......@@ -5,14 +5,10 @@
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
......@@ -375,9 +371,6 @@ class LlamaForCausalLM(nn.Module):
weight_loader(param, loaded_weight)
if name is None or loaded_weight is None:
if get_tensor_model_parallel_rank() == 0:
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
for name, loaded_weight in weights:
load_weights_per_param(name, loaded_weight)
else:
......
......@@ -222,6 +222,7 @@ def launch_server(
detokenizer_port=ports[2],
nccl_ports=ports[3:],
)
logger.info(f"{server_args=}")
# Handle multi-node tensor parallelism
if server_args.nnodes > 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment