Commit 98a011e9 authored by zhuwenwen's avatar zhuwenwen
Browse files

restore the initial fp8 related implementation

remove medusa related files
parent 80c483dd
...@@ -266,15 +266,15 @@ set(VLLM_EXT_SRC ...@@ -266,15 +266,15 @@ set(VLLM_EXT_SRC
"csrc/attention/attention_with_mask_kernels_opt.cu" "csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu" "csrc/attention/attention_with_mask_kernels_opt_tc.cu"
"csrc/opt/layernorm_kernels_opt.cu" "csrc/opt/layernorm_kernels_opt.cu"
# "csrc/layernorm_quant_kernels.cu" "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu" "csrc/sampler.cu"
"csrc/cuda_view.cu" "csrc/cuda_view.cu"
# "csrc/quantization/gptq/q_gemm.cu" # "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu" "csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
# "csrc/quantization/activation_kernels.cu" "csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/custom_all_reduce.cu" "csrc/custom_all_reduce.cu"
"csrc/torch_bindings.cpp") "csrc/torch_bindings.cpp")
......
...@@ -123,7 +123,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -123,7 +123,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
# "-DENABLE_FP8" "-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-Werror=unused-variable" "-Werror=unused-variable"
......
...@@ -6,9 +6,7 @@ ...@@ -6,9 +6,7 @@
*/ */
#include "type_convert.cuh" #include "type_convert.cuh"
#ifndef USE_ROCM
#include "quantization/fp8/common.cuh" #include "quantization/fp8/common.cuh"
#endif
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "cub_helpers.h" #include "cub_helpers.h"
......
...@@ -224,15 +224,15 @@ void apply_repetition_penalties_(torch::Tensor& logits, ...@@ -224,15 +224,15 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& output_mask, const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties); const torch::Tensor& repetition_penalties);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale, torch::Tensor& weight, torch::Tensor& scale,
// double epsilon); double epsilon);
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
// torch::Tensor& input, torch::Tensor& input,
// torch::Tensor& residual, torch::Tensor& residual,
// torch::Tensor& weight, torch::Tensor& weight,
// torch::Tensor& scale, double epsilon); torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out, void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input, torch::Tensor const& input,
...@@ -248,8 +248,8 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, ...@@ -248,8 +248,8 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
// void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale); torch::Tensor& scale);
#ifndef USE_ROCM #ifndef USE_ROCM
void silu_and_mul_nvfp4_quant(torch::Tensor& out, void silu_and_mul_nvfp4_quant(torch::Tensor& out,
...@@ -257,12 +257,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out, ...@@ -257,12 +257,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input, torch::Tensor& input,
torch::Tensor& input_global_scale); torch::Tensor& input_global_scale);
#endif #endif
// void silu_mul_fp8_quant_deep_gemm_cuda( void silu_mul_fp8_quant_deep_gemm_cuda(
// const at::Tensor& input, // (E, T, 2*H) const at::Tensor& input, // (E, T, 2*H)
// const at::Tensor& counts, // (E) const at::Tensor& counts, // (E)
// at::Tensor& y_q, // (E, T, H) [OUT] at::Tensor& y_q, // (E, T, H) [OUT]
// at::Tensor& y_s, // (E, T, H//group_size) [OUT] at::Tensor& y_s, // (E, T, H//group_size) [OUT]
// int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens); int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
...@@ -438,15 +438,15 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -438,15 +438,15 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); // void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale); torch::Tensor const& scale);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale); torch::Tensor& scale);
// void dynamic_per_token_scaled_fp8_quant( void dynamic_per_token_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// std::optional<torch::Tensor> const& scale_ub); std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& A, const torch::Tensor& B,
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh" #include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead // TODO(luka/varun):refactor common.cuh to use this file instead
// #include "quantization/fp8/common.cuh" #include "quantization/fp8/common.cuh"
namespace vllm { namespace vllm {
......
...@@ -32,12 +32,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -32,12 +32,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#define stride_tag #define stride_tag
#endif #endif
// ops.def( ops.def(
// "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! " "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
// "y_q, Tensor! y_s, int group_size, " "y_q, Tensor! y_s, int group_size, "
// "bool use_ue8m0, int num_parallel_tokens) -> ()"); "bool use_ue8m0, int num_parallel_tokens) -> ()");
// ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA, ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA,
// &silu_mul_fp8_quant_deep_gemm_cuda); &silu_mul_fp8_quant_deep_gemm_cuda);
ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
...@@ -269,9 +269,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -269,9 +269,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
// ops.def( ops.def(
// "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
// ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
#ifndef USE_ROCM #ifndef USE_ROCM
ops.def( ops.def(
...@@ -366,20 +366,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -366,20 +366,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant // Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
// ops.def( ops.def(
// "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
// "Tensor scale, float epsilon) -> " "Tensor scale, float epsilon) -> "
// "()"); "()");
// ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
// &rms_norm_static_fp8_quant); &rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization. // In-place fused Add and RMS Normalization.
// ops.def( ops.def(
// "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! residual, Tensor weight, " "Tensor! residual, Tensor weight, "
// "Tensor scale, float epsilon) -> ()"); "Tensor scale, float epsilon) -> ()");
// ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
// &fused_add_rms_norm_static_fp8_quant); &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels // Fused Layernorm + Quant kernels
ops.def( ops.def(
...@@ -741,25 +741,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -741,25 +741,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); // ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor. // Compute FP8 quantized tensor for given scaling factor.
// ops.def( ops.def(
// "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> " "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
// "()"); "()");
// ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def( ops.def(
// "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
// "-> " "-> "
// "()"); "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// ops.def( ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! scale, Tensor? scale_ub) -> " "Tensor! scale, Tensor? scale_ub) -> "
// "()"); "()");
// ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
// &dynamic_per_token_scaled_fp8_quant); &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor. // Compute int8 quantized tensor for given scaling factor.
ops.def( ops.def(
......
# Medusa Decoding
本文说明如何使用vllm构建和运行medusa模型
## Overview
Medusa是一种大模型并行解码算法,除了支持官方提供的Top1-proposer,我们还支持tree-style并行解码,target model和draft model均可多卡推理
与其他模型不同,medusa解码需要一个base model和若干Medusa heads.
vLLM中Medusa模型的实现位于 `vllm/model_executor/models/medusa.py`。
## Support Matrix
* FP16
* BF16
* PAGED_KV_CACHE
* Tensor Parallel
### convert Medusa model weights
Medusa模型需要先转换为vLLM中Medusa的模型格式:
```bash
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]"
```
此处model.bin是训练后保存的medusa head权重,如果希望采用Top1-proposer,medusa_choices可以不设置
### Run tree-style generation server
```bash
VLLM_TREE_DECODING=1 python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 9 \
--use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \
--dtype float16 --trust-remote-code --port 8086 \
--num-speculative-heads 4 --num-speculative-tokens 24
```
注意:
- num_speculative_tokens = len(medusa_choices) + 1
- medusa_choices的个数不能太多,否则在多batch场景下会降低推理速度
- speculative-disable-by-batch-size要大于max-num-seqs,否则当batch等于max-num-seqs时,不会走并行解码
### Run Top1-proposer server
```bash
python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 9 \
--use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \
--dtype float16 --trust-remote-code --port 8086 \
--num-speculative-tokens 4
```
注意:
使用Top1-proposer时,num-speculative-tokens就是medusa head的个数
### Do request
```bash
curl http://localhost:8086/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "qwen_medusa",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256,
"temperature": 0.0
}'
```
### Run tree-style benchmark
```bash
VLLM_TREE_DECODING=1 python /work/test/medusa_benchmark_throughput.py --model /models/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 4 --speculative-model /work/medusa/vllm-medusa1-qwen2-72b-head-4 --speculative-draft-tensor-parallel-size 4 --speculative-disable-by-batch-size 9 --use-v2-block-manager --spec-decoding-acceptance-method typical_acceptance_sampler --max-model-len 1024 --dataset /work/medusa_benchmark_data.json --num-speculative-heads 4 --num-speculative-tokens 24 --gpu-memory-utilization 0.95
```
### Run Top1-proposer benchmark
```bash
python /work/test/medusa_benchmark_throughput.py --model /models/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 4 --speculative-model /work/medusa/vllm-medusa1-qwen2-72b-head-4 --speculative-draft-tensor-parallel-size 4 --speculative-disable-by-batch-size 9 --use-v2-block-manager --spec-decoding-acceptance-method typical_acceptance_sampler --max-model-len 1024 --dataset /work/medusa_benchmark_data.json --num-speculative-tokens 4 --gpu-memory-utilization 0.95
```
可设置max-num-seqs对不同的batch进行性能测试
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple
import numpy as np
import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.inputs import PromptInputs
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
from vllm.lora.request import LoRARequest
def nullable_str(val: str):
    """Map an empty value or the literal string ``"None"`` to ``None``.

    Any other value is handed back unchanged.
    """
    return None if (not val or val == "None") else val
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    """Build benchmark requests from a JSON dataset file.

    The file at ``dataset_path`` is parsed as JSON and iterated; the
    ``"prompt"`` field of every entry becomes the request text. Prompts are
    taken in file order and no length-based filtering is applied.

    Args:
        dataset_path: Path of the JSON dataset to read.
        num_requests: Collection stops as soon as this many requests exist.
        tokenizer: Called on each prompt; the length of ``.input_ids`` is
            recorded as the prompt length.
        fixed_output_len: Value stored as the output length of every request.

    Returns:
        A list of ``(prompt, prompt_len, output_len)`` tuples.

    Raises:
        ValueError: If ``fixed_output_len`` is given and is smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Read the whole dataset, then keep only the prompt text of each entry.
    with open(dataset_path) as dataset_file:
        entries = json.load(dataset_file)
    prompt_texts = [entry["prompt"] for entry in entries]

    samples: List[Tuple[str, int, int]] = []
    for text in prompt_texts:
        # Stop once enough requests were collected.
        if len(samples) == num_requests:
            break
        # Tokenize only to learn the prompt length; the text itself is kept.
        num_prompt_tokens = len(tokenizer(text).input_ids)
        samples.append((text, num_prompt_tokens, fixed_output_len))
    return samples
def run_vllm(
    warmup_requests: List[Tuple[str, int, int]],
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    max_num_seqs: int = 8,
    speculative_model: str=None,
    speculative_draft_tensor_parallel_size: int = 1,
    speculative_disable_by_batch_size: int = 4,
    spec_decoding_acceptance_method: str = None,
    enable_lora: bool = False,
    max_lora_rank: int = 32,
    lora_extra_vocab_size: int = 0,
    lora_target_modules: List[str] = None,
    num_speculative_heads: int = 5,
    num_speculative_tokens: int = 64,
    use_new_beam_search_impl: bool = False,
    lora_modules: str = None
) -> Tuple[float, int]:
    """Run the synchronous throughput benchmark through ``vllm.LLM``.

    An ``LLM`` is constructed from the engine-related arguments (they are
    forwarded unchanged as keyword arguments), the warmup batch is generated
    ``args.num_iters_warmup`` times, and then the measured batch is generated
    once while being timed.

    Args:
        warmup_requests: ``(prompt, prompt_len, output_len)`` tuples used for
            the untimed warmup iterations.
        requests: ``(prompt, prompt_len, output_len)`` tuples for the timed
            run. Only the prompt and ``output_len`` are used here.
        n: Passed to ``SamplingParams(n=...)`` for every request.
        use_beam_search: Passed to ``SamplingParams(use_beam_search=...)``.
        use_new_beam_search_impl: Accepted for call compatibility; this
            function does not read it.
        lora_modules: When not ``None``, every ``generate`` call is given a
            ``LoRARequest("medusa-lora", 1, lora_modules)``.
        (all remaining parameters): Forwarded as-is to ``LLM(...)``.

    Returns:
        ``(elapsed_seconds, total_out_tokens)`` where ``elapsed_seconds``
        covers only the timed ``generate`` call and ``total_out_tokens`` sums
        the token count of the first output of every request.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
        max_num_seqs=max_num_seqs,
        speculative_model=speculative_model,
        speculative_draft_tensor_parallel_size=speculative_draft_tensor_parallel_size,
        speculative_disable_by_batch_size=speculative_disable_by_batch_size,
        spec_decoding_acceptance_method=spec_decoding_acceptance_method,
        enable_lora=enable_lora,
        max_lora_rank=max_lora_rank,
        lora_extra_vocab_size=lora_extra_vocab_size,
        lora_target_modules=lora_target_modules,
        num_speculative_heads=num_speculative_heads,
        num_speculative_tokens=num_speculative_tokens
    )

    def _build_batch(batch):
        """Split request tuples into prompts and matching SamplingParams."""
        texts: List[str] = []
        params: List[SamplingParams] = []
        for prompt, _, output_len in batch:
            texts.append(prompt)
            params.append(
                SamplingParams(
                    n=n,
                    temperature=0.0,
                    top_p=1.0,
                    use_beam_search=use_beam_search,
                    ignore_eos=False,
                    max_tokens=output_len,
                ))
        return texts, params

    # Warmup and measured batches share the same sampling configuration.
    prompts, sampling_params = _build_batch(requests)
    warmup_prompts, warmup_sampling_params = _build_batch(warmup_requests)

    # The LoRA request is attached only when LoRA modules were supplied.
    generate_kwargs = {}
    if lora_modules is not None:
        generate_kwargs["lora_request"] = LoRARequest("medusa-lora", 1,
                                                      lora_modules)

    print("Warming up...")
    # NOTE(review): `args` is not a parameter of this function; this relies on
    # a module-level `args` existing when the script runs — confirm against
    # the `__main__` block.
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True,
                     **generate_kwargs)

    start = time.perf_counter()
    outputs = llm.generate(prompts, sampling_params, use_tqdm=False,
                           **generate_kwargs)
    # Stop the clock before reporting so console output is not counted as
    # generation latency.
    end = time.perf_counter()

    total_out_tokens = 0
    for output in outputs:
        first_completion = output.outputs[0]
        print("token_ids len:{} text:{}".format(
            len(first_completion.token_ids), first_completion.text))
        total_out_tokens += len(first_completion.token_ids)
    return end - start, total_out_tokens
async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    disable_frontend_multiprocessing: bool = False,
    max_num_seqs: int = 8,
    speculative_model: str=None,
    speculative_draft_tensor_parallel_size: int = 1,
    speculative_disable_by_batch_size: int = 4,
    spec_decoding_acceptance_method: str = None,
    enable_lora: bool = False,
    max_lora_rank: int = 32,
    lora_extra_vocab_size: int = 0,
    lora_target_modules: List[str] = None,
    num_speculative_heads: int = 5,
    num_speculative_tokens: int = 64,
    use_new_beam_search_impl: bool = False,
    lora_modules: str = None
) -> Tuple[float, int]:
    """Run the throughput benchmark through the async engine client.

    ``AsyncEngineArgs`` is built from the engine-related arguments (forwarded
    unchanged, plus ``worker_use_ray=False`` and ``disable_log_requests=True``),
    one ``generate`` stream is started per request, and all streams are
    drained through ``merge_async_iterators`` while being timed.

    Args:
        requests: ``(prompt, prompt_len, output_len)`` tuples. Only the prompt
            and ``output_len`` are used here.
        n: Passed to ``SamplingParams(n=...)`` for every request.
        use_beam_search: Passed to ``SamplingParams``; it also selects the
            temperature (0.0 with beam search, 1.0 otherwise).
        disable_frontend_multiprocessing: Passed to
            ``build_async_engine_client_from_engine_args``.
        use_new_beam_search_impl: Accepted for call compatibility; this
            function does not read it.
        lora_modules: Accepted for call compatibility; this function does not
            read it.
        (all remaining parameters): Forwarded as-is to ``AsyncEngineArgs``.

    Returns:
        ``(elapsed_seconds, total_out_tokens)``. The token total adds up, per
        request id, the first-output token count of the last streamed result
        seen for that request.
    """
    from vllm import SamplingParams

    engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
        worker_use_ray=False,
        disable_log_requests=True,
        max_num_seqs=max_num_seqs,
        speculative_model=speculative_model,
        speculative_draft_tensor_parallel_size=speculative_draft_tensor_parallel_size,
        speculative_disable_by_batch_size=speculative_disable_by_batch_size,
        spec_decoding_acceptance_method=spec_decoding_acceptance_method,
        enable_lora=enable_lora,
        max_lora_rank=max_lora_rank,
        lora_extra_vocab_size=lora_extra_vocab_size,
        lora_target_modules=lora_target_modules,
        num_speculative_heads=num_speculative_heads,
        num_speculative_tokens=num_speculative_tokens
    )

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Add the requests to the engine.
        prompts: List[str] = []
        sampling_params: List[SamplingParams] = []
        for prompt, _, output_len in requests:
            prompts.append(prompt)
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=0.0 if use_beam_search else 1.0,
                    top_p=1.0,
                    use_beam_search=use_beam_search,
                    ignore_eos=False,
                    max_tokens=output_len,
                ))

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
            generators.append(llm.generate(prompt, sp, request_id=f"test{i}"))

        # Each streamed result overwrites the count stored for its request id,
        # so after draining only the last seen length per request remains.
        tokens_per_request = {}
        async for _, res in merge_async_iterators(*generators):
            tokens_per_request[res.request_id] = len(res.outputs[0].token_ids)
        end = time.perf_counter()

        total_out_tokens = sum(tokens_per_request.values())
        return end - start, total_out_tokens
def main(args: argparse.Namespace):
    """Run one throughput benchmark described by the parsed CLI arguments.

    Steps: seed ``random``, load the tokenizer, build the request list (from
    ``args.dataset`` when given, otherwise a synthetic ``"hi"``-repeated
    prompt), run either the async or the synchronous benchmark, print the
    latency/throughput figures, and optionally dump them to
    ``args.output_json``.

    Args:
        args: Parsed command-line namespace; the attributes read here are the
            ones registered on the parser in the ``__main__`` block.
    """
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    # Fall back to the model path when no tokenizer was named, since
    # ``from_pretrained`` cannot be called with ``None``.
    tokenizer_name = (args.tokenizer
                      if args.tokenizer is not None else args.model)
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name, trust_remote_code=args.trust_remote_code)

    warmup_prompt = "hi" * 10
    warmup_requests = [(warmup_prompt, 10, 10) for _ in range(1)]

    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    # Arguments shared by both runners. They are passed by keyword so each
    # value is bound to the parameter it is meant for; the earlier positional
    # lists bound --disable-frontend-multiprocessing to an unused parameter.
    shared_kwargs = dict(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
        n=args.n,
        use_beam_search=args.use_beam_search,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        enable_prefix_caching=args.enable_prefix_caching,
        enable_chunked_prefill=args.enable_chunked_prefill,
        max_num_batched_tokens=args.max_num_batched_tokens,
        distributed_executor_backend=args.distributed_executor_backend,
        gpu_memory_utilization=args.gpu_memory_utilization,
        num_scheduler_steps=args.num_scheduler_steps,
        use_v2_block_manager=args.use_v2_block_manager,
        download_dir=args.download_dir,
        load_format=args.load_format,
        disable_async_output_proc=args.disable_async_output_proc,
        max_num_seqs=args.max_num_seqs,
        speculative_model=args.speculative_model,
        speculative_draft_tensor_parallel_size=args.
        speculative_draft_tensor_parallel_size,
        speculative_disable_by_batch_size=args.
        speculative_disable_by_batch_size,
        spec_decoding_acceptance_method=args.spec_decoding_acceptance_method,
        enable_lora=args.enable_lora,
        max_lora_rank=args.max_lora_rank,
        lora_extra_vocab_size=args.lora_extra_vocab_size,
        lora_target_modules=args.lora_target_modules,
        num_speculative_heads=args.num_speculative_heads,
        num_speculative_tokens=args.num_speculative_tokens,
    )

    if args.async_engine:
        elapsed_time, total_out_tokens = uvloop.run(
            run_vllm_async(
                requests,
                disable_frontend_multiprocessing=args.
                disable_frontend_multiprocessing,
                **shared_kwargs))
    else:
        elapsed_time, total_out_tokens = run_vllm(
            warmup_requests,
            requests,
            use_new_beam_search_impl=args.use_new_beam_search_impl,
            lora_modules=args.lora_modules,
            **shared_kwargs)

    # Prompt tokens come from the request tuples, output tokens from the run.
    total_num_tokens = total_out_tokens + sum(prompt_len
                                              for _, prompt_len, _ in requests)
    print(f"Latency: {elapsed_time:.2f} s")
    print(f"All Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
    print(f"Generate Throughput: {total_out_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=256,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument('--num-iters-warmup',
type=int,
default=1,
help='Number of iterations to run for warmup.')
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument("--device",
type=str,
default="auto",
choices=DEVICE_OPTIONS,
help='device type for vLLM execution')
parser.add_argument(
"--num-scheduler-steps",
type=int,
default=1,
help="Maximum number of forward steps per scheduler call.")
parser.add_argument("--use-v2-block-manager",
action='store_true',
help="Enable block manager v2.")
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="Enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
'bitsandbytes'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available.\n'
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.\n'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument(
'--speculative-model',
type=nullable_str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=EngineArgs.speculative_draft_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument(
'--spec-decoding-acceptance-method',
type=str,
default=EngineArgs.spec_decoding_acceptance_method,
choices=['rejection_sampler', 'typical_acceptance_sampler'],
help='Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
# NOTE: ``type=bool`` would call ``bool()`` on the raw command-line string, and
# every non-empty string (including "False" and "0") is truthy, so the flag
# could never be switched off explicitly. Parse the text instead: the usual
# "off" spellings map to False, any other value keeps mapping to True.
parser.add_argument('--merge-lora',
                    type=lambda raw: raw.strip().lower() not in
                    ('', '0', 'false', 'f', 'no', 'n', 'off'),
                    default=False,
                    help='If set to True, the weights of the base layer will be merged with the weights of Lora.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument('--lora-target-modules',
nargs='*',
default=None,
help='List of lora module name, If not specified, modules will be chosen according to the model architecture.')
parser.add_argument(
'--num-speculative-heads',
type=int,
default=EngineArgs.num_speculative_heads,
help='The number of speculative heads to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--lora-modules',
type=nullable_str,
default=None,
help=
'Path of lora model.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
main(args)
\ No newline at end of file
import os
import ast
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
from addict import Dict
import yaml
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from safetensors.torch import save_model, safe_open
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
# Default multiple that vocabulary sizes are rounded up to (see pad_vocab_size).
DEFAULT_VOCAB_PADDING_SIZE = 64
# Key templates for tensors in the trained Medusa checkpoint. main() formats
# the two-placeholder templates with (head index, layer index) and the
# one-placeholder template with the head index only.
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.weight'
# NOTE(review): "NEMA" looks like a typo for "NAME". The sub-index is fixed to
# 1 -- presumably this only matches single-layer checkpoints; verify against
# the training code before using more than one layer.
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.bias'
# Matching parameter names inside the Medusa module defined in this file
# (Medusa.blocks[i].layers[j] and Medusa.lm_heads[i]).
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias'
VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE = 'lm_heads.{}.weight'
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Overwrite ``param`` in place with the values of ``loaded_weight``.

    Args:
        param: Destination tensor/parameter; its ``.data`` is written to.
        loaded_weight: Source tensor; must have exactly the same size.

    Raises:
        AssertionError: If the two tensors differ in size.
    """
    expected_size = param.size()
    assert expected_size == loaded_weight.size()
    destination = param.data
    destination.copy_(loaded_weight)
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``.

    Args:
        vocab_size: Number of vocabulary entries to pad.
        pad_to: Granularity of the padding.

    Returns:
        The smallest multiple of ``pad_to`` that is >= ``vocab_size``
        (for positive ``pad_to``).
    """
    # Ceiling division expressed with divmod; the remainder is irrelevant.
    num_blocks, _ = divmod(vocab_size + pad_to - 1, pad_to)
    return num_blocks * pad_to
class MedusaConfig(PretrainedConfig):
    """Configuration for the Medusa speculative-head module in this file.

    Args:
        hidden_size: Width of the hidden states fed to the heads.
        vocab_size: Size of the full vocabulary.
        num_heads: Number of Medusa heads.
        num_hidden_layers: Number of linear layers in each residual block.
        max_paths: Stored on the config; not used elsewhere in this file.
        topk: Stored on the config; not used elsewhere in this file.
        truncated_vocab_size: Size of the reduced vocabulary; falls back to
            ``vocab_size`` when omitted.
        **kwargs: Forwarded to ``PretrainedConfig``. ``architectures``
            defaults to ``["MedusaModel"]`` when the caller does not set it.
    """
    model_type = "medusa"

    def __init__(self,
                 hidden_size: int = 4096,
                 vocab_size: int = 32001,
                 num_heads: int = 5,
                 num_hidden_layers: int = 1,
                 max_paths: int = 64,
                 topk: int = 10,
                 truncated_vocab_size: Optional[int] = None,
                 **kwargs):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_paths = max_paths
        self.topk = topk
        self.max_seq_len = 1 << 20

        # Without an explicit truncation the full vocabulary is used.
        if truncated_vocab_size is None:
            truncated_vocab_size = vocab_size
        self.truncated_vocab_size = truncated_vocab_size

        kwargs.setdefault("architectures", ["MedusaModel"])
        super().__init__(**kwargs)

    @property
    def num_attention_heads(self):
        """Always 0 for this config."""
        return 0

    @property
    def num_lookahead_tokens(self):
        """Alias of ``num_heads``."""
        return self.num_heads

    @num_lookahead_tokens.setter
    def num_lookahead_tokens(self, num_lookahead_tokens: int):
        self.num_heads = num_lookahead_tokens
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding table whose row count is padded to a multiple of ``padding_size``.

    Simplified, single-device variant: this class only computes the padded
    sizes and creates one weight; no tensor-parallel sharding is done here.

    Padding is applied in two steps. The original vocabulary is padded first,
    then the extra (``num_embeddings - org_num_embeddings``) rows are appended
    and the total is padded again. Example with ``padding_size=64``: an
    original vocabulary of 1010 pads to 1024; adding 16 extra rows gives 1040,
    which pads to 1088.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters (torch default dtype when None).
        org_num_embeddings: original vocabulary size (without added rows).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer
        prefix: full name of the layer in the state dict
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        extra_rows = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        unpadded_total = self.org_vocab_size_padded + extra_rows
        self.num_embeddings_padded = pad_vocab_size(unpadded_total,
                                                    self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded
        self.embedding_dim = embedding_dim

        # Ask the quantization config for a method first; fall back to the
        # unquantized one when there is no config or it returns None.
        method = (quant_config.get_quant_method(self, prefix=prefix)
                  if quant_config is not None else None)
        if method is None:
            method = UnquantizedLinearMethod()
        self.linear_method: QuantizeMethodBase = method

        dtype = (torch.get_default_dtype()
                 if params_dtype is None else params_dtype)
        # presumably registers ``self.weight`` on this module (forward reads
        # it) -- confirm against the quant method implementation.
        self.linear_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_padded],
                                          self.embedding_dim,
                                          self.num_embeddings_padded,
                                          params_dtype=dtype,
                                          weight_loader=self.weight_loader)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``; shapes must match exactly."""
        target = param.data
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)

    def forward(self, input_):
        """Look up the embedding rows for the token ids in ``input_``."""
        token_ids = input_.long()
        return F.embedding(token_ids, self.weight)
class ParallelLMHead(VocabParallelEmbedding):
    """LM head built on top of ``VocabParallelEmbedding``.

    Holds the output-logit weight (and optional bias) used by the sampler;
    the module itself is never called directly.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer
        prefix: full name of the layer in the state dict
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        if bias:
            # The base class in this file is not sharded and only exposes
            # ``num_embeddings_padded``; the previously used
            # ``num_embeddings_per_partition`` attribute does not exist here
            # and raised AttributeError whenever bias=True.
            self.bias = Parameter(
                torch.empty(self.num_embeddings_padded, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        """Always raises: the head's weights are consumed by the sampler."""
        del input_
        raise RuntimeError("LMHead's weights should be used in the sampler.")
class ResidualBlock(nn.Module):
    """Stack of ``num_layers`` square linear layers with SiLU residual links.

    Each layer computes ``x + SiLU(Linear(x))`` and passes the result on to
    the next layer.
    """

    def __init__(self, hidden_size: int, num_layers: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every residual layer in order and return the final state."""
        hidden = x
        for linear in self.layers:
            update = self.act(linear(hidden))
            hidden = hidden + update
        return hidden
class Medusa(nn.Module):
    """Medusa heads: one residual block and one LM head per head index.

    Args:
        config: Supplies ``hidden_size``, ``num_hidden_layers``,
            ``num_heads``, ``vocab_size`` and ``truncated_vocab_size``.
        **_: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, config: MedusaConfig, **_) -> None:
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_heads = nn.ModuleList([
            ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            ) for _ in range(self.config.num_heads)
        ])

        # Populated by load_weights() when the checkpoint ships a "token_map"
        # and the vocabulary is truncated.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """Run every residual block on ``hidden_states``.

        Returns:
            One tensor per head, in head order. The LM heads are not applied
            here.
        """
        return [block(hidden_states) for block in self.blocks]

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this module.

        Every ``"medusa_heads."`` substring is removed from the names. A
        ``"token_map"`` entry is kept only when the vocabulary is truncated;
        entries whose name matches no parameter are silently skipped.

        Raises:
            AssertionError: If the vocabulary is truncated but no token map
                was loaded, or if a weight loader rejects a tensor.
        """
        params_dict = dict(self.named_parameters())

        # First pass: collect the tensors and pick up the token map so that it
        # is available for every LM head regardless of iteration order.
        weights_map = {}
        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight

        # Second pass: copy the tensors, selecting the token_map rows of LM
        # head weights that have more rows than the map has entries.
        for name, loaded_weight in weights_map.items():
            if "lm_head" in name and self.token_map is not None and\
                loaded_weight.shape[0] > self.token_map.shape[0]:
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if self.token_map is not None:
            # Tensor.to() returns a new tensor instead of moving in place, so
            # the result must be stored; the previous bare call was a no-op.
            self.token_map.data = self.token_map.data.to(
                device=self.lm_heads[0].weight.device)

        assert (self.truncated_vocab_size
                == self.orig_vocab_size) or (self.token_map is not None)
class CustomMedusaConfig(PretrainedConfig):
    """Config object written next to the converted Medusa weights.

    Args:
        name_or_path: Stored as ``_name_or_path``.
        architectures: Architecture names; ``["MedusaModel"]`` when omitted.
        hidden_size: Width of the hidden states.
        model_type: Model type string stored on the instance.
        num_heads: Number of Medusa heads.
        num_hidden_layers: Number of layers per residual block.
        transformers_version: Version string stored on the config.
        truncated_vocab_size: Size of the reduced vocabulary, if any.
        vocab_size: Size of the full vocabulary.
        medusa_choices: Optional Medusa tree choices (list of index paths).
        **kwargs: Forwarded to ``PretrainedConfig``.
    """
    model_type = "medusa"

    def __init__(self,
                 name_or_path: str = "S-3000/vllm-medusa-qwen1.5-7b-chat",
                 architectures: Optional[List[str]] = None,
                 hidden_size: int = 4096,
                 model_type: str = "medusa",
                 num_heads: int = 5,
                 num_hidden_layers: int = 1,
                 transformers_version: str = "4.41.2",
                 truncated_vocab_size: Optional[int] = None,
                 vocab_size: int = 151936,
                 medusa_choices: Optional[List[List[int]]] = None,
                 **kwargs):
        super().__init__(**kwargs)
        # Build the default list per instance: a list literal in the signature
        # would be one shared object stored on (and mutable through) every
        # config created without an explicit ``architectures``.
        if architectures is None:
            architectures = ["MedusaModel"]
        # Explicit arguments are applied after the base initializer so they
        # take precedence over anything it derived from ``kwargs``.
        self._name_or_path = name_or_path
        self.architectures = architectures
        self.hidden_size = hidden_size
        self.model_type = model_type
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.transformers_version = transformers_version
        self.truncated_vocab_size = truncated_vocab_size
        self.vocab_size = vocab_size
        self.medusa_choices = medusa_choices
def _copy_checkpoint_tensor(params_dict, checkpoint, vllm_name, trained_name):
    """Copy ``checkpoint[trained_name]`` into the parameter ``vllm_name``."""
    param = params_dict[vllm_name]
    trained_tensor = checkpoint[trained_name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, trained_tensor)


def main(args):
    """Convert a trained Medusa checkpoint into the layout used by this file.

    Builds a ``Medusa`` module from the command-line sizes, copies the LM-head
    and residual-block tensors out of the trained checkpoint, then writes
    ``model.safetensors`` and the config into ``args.output_dir``.

    Args:
        args: Parsed command-line namespace providing ``medusa_model_path``,
            ``vocab_size``, ``medusa_num_heads``, ``medusa_num_layers``,
            ``hidden_size``, ``output_dir`` and ``medusa_choices``.

    Raises:
        KeyError: If an expected tensor name is missing from the checkpoint
            or from the model.
        AssertionError: If a checkpoint tensor does not match the shape of
            the parameter it is copied into.
    """
    medusa_head_num = args.medusa_num_heads
    medusa_num_layers = args.medusa_num_layers

    # num_hidden_layers must be forwarded: otherwise the model is always built
    # with the default single layer per block and the block loop below fails
    # with KeyError as soon as --medusa_num_layers is greater than 1.
    config = MedusaConfig(hidden_size=args.hidden_size,
                          vocab_size=args.vocab_size,
                          num_heads=medusa_head_num,
                          num_hidden_layers=medusa_num_layers)
    medusa_model = Medusa(config)
    params_dict = dict(medusa_model.named_parameters())

    # NOTE(review): torch.load unpickles the file -- only use it on trusted
    # checkpoints. map_location="cpu" lets checkpoints saved on an accelerator
    # be converted on machines without one.
    trained_medusa_model = torch.load(args.medusa_model_path,
                                      map_location="cpu")

    # LM head of every Medusa head.
    for i in range(medusa_head_num):
        _copy_checkpoint_tensor(
            params_dict, trained_medusa_model,
            VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE.format(i),
            TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE.format(i))

    # Linear weight and bias of every residual-block layer.
    for i in range(medusa_head_num):
        for j in range(medusa_num_layers):
            _copy_checkpoint_tensor(
                params_dict, trained_medusa_model,
                VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j),
                TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j))
            _copy_checkpoint_tensor(
                params_dict, trained_medusa_model,
                VLLM_BLOCK_BIAS_NAME_TEMPLATE.format(i, j),
                TRAINED_BLOCK_BIAS_NAME_TEMPLATE.format(i, j))

    # exist_ok=True already covers the "directory is present" case.
    os.makedirs(args.output_dir, exist_ok=True)
    save_model(medusa_model, os.path.join(args.output_dir,
                                          "model.safetensors"))

    medusa_choices = (ast.literal_eval(args.medusa_choices)
                      if args.medusa_choices is not None else None)
    to_save_config = CustomMedusaConfig(
        name_or_path=os.path.join(args.output_dir, "config.json"),
        hidden_size=args.hidden_size,
        num_heads=medusa_head_num,
        num_hidden_layers=medusa_num_layers,
        vocab_size=args.vocab_size,
        medusa_choices=medusa_choices)
    to_save_config.save_pretrained(args.output_dir)

    # validate weight
    # with safe_open(os.path.join(args.output_dir, "model.safetensors"), framework="pt") as f:
    #     param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0))
    #     trained_param = trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0)]
    #     mse_value = torch.nn.functional.mse_loss(param.cpu(), trained_param.cpu())
    #     print("weight mse:", mse_value)
if __name__ == "__main__":
    # Command-line entry point: collect the conversion options and run main().
    parser = argparse.ArgumentParser(description="Medusa Model Evaluator")

    # (flag, value type, help text) of every mandatory option, in --help order.
    required_options = (
        ("--medusa_model_path", str, "Path to the medusa model file."),
        ("--vocab_size", int, "Vocab size"),
        ("--medusa_num_heads", int, "Number of Medusa heads"),
        ("--medusa_num_layers", int, "Number of Medusa layers"),
        ("--hidden_size", int, "Hidden size"),
        ("--output_dir", str, "Output dir"),
    )
    for flag, value_type, help_text in required_options:
        parser.add_argument(flag,
                            type=value_type,
                            required=True,
                            help=help_text)

    # Optional: a Python-literal list of index paths, parsed later in main().
    parser.add_argument(
        '--medusa_choices',
        type=str,
        default=None,
        help=("Medusa choice to use, if not none, will use Medusa decoding."
              " E.g.: [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]] for 9 medusa tokens."))

    args = parser.parse_args()
    main(args)
...@@ -901,20 +901,20 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, ...@@ -901,20 +901,20 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
# return torch.empty_like(b, memory_format=torch.contiguous_format) # return torch.empty_like(b, memory_format=torch.contiguous_format)
# if hasattr(torch.ops._C, "allspark_w8a16_gemm"): if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
# @register_fake("_C::allspark_w8a16_gemm") @register_fake("_C::allspark_w8a16_gemm")
# def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor,
# b_scales: torch.Tensor, b_scales: torch.Tensor,
# b_qzeros: Optional[torch.Tensor], b_qzeros: Optional[torch.Tensor],
# n: torch.SymInt, group_size: torch.SymInt, n: torch.SymInt, group_size: torch.SymInt,
# sm_count: torch.SymInt, sm_count: torch.SymInt,
# sm_version: torch.SymInt, sm_version: torch.SymInt,
# CUBLAS_M_THRESHOLD: torch.SymInt, CUBLAS_M_THRESHOLD: torch.SymInt,
# has_zp: bool, has_zp: bool,
# n32k16_reorder: bool) -> torch.Tensor: n32k16_reorder: bool) -> torch.Tensor:
# m = a.size(0) m = a.size(0)
# return torch.empty((m, n), device=a.device, dtype=a.dtype) return torch.empty((m, n), device=a.device, dtype=a.dtype)
if hasattr(torch.ops._C, "ggml_dequantize"): if hasattr(torch.ops._C, "ggml_dequantize"):
...@@ -1664,67 +1664,66 @@ def scaled_fp4_experts_quant( ...@@ -1664,67 +1664,66 @@ def scaled_fp4_experts_quant(
return output, output_scales return output, output_scales
# fp8 def scaled_fp8_quant(
# def scaled_fp8_quant( input: torch.Tensor,
# input: torch.Tensor, scale: Optional[torch.Tensor] = None,
# scale: Optional[torch.Tensor] = None, num_token_padding: Optional[int] = None,
# num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None,
# scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False,
# use_per_token_if_dynamic: bool = False, output: Optional[torch.Tensor] = None,
# output: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]:
# ) -> tuple[torch.Tensor, torch.Tensor]: """
# """ Quantize input tensor to FP8 and return quantized tensor and scale.
# Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
# This function supports both static and dynamic quantization: If you provide the scale, it will use static scaling and if you omit it,
# provide the scale, it will use static scaling and if you omit it, the scale will be determined dynamically. The function also allows
# the scale will be determined dynamically. The function also allows optional padding of the output tensors for downstream kernels that
# optional padding of the output tensors for downstream kernels that will benefit from padding.
# will benefit from padding.
Args:
# Args: input: The input tensor to be quantized to FP8
# input: The input tensor to be quantized to FP8 scale: Optional scaling factor for the FP8 quantization
# scale: Optional scaling factor for the FP8 quantization scale_ub: Optional upper bound for scaling factor in dynamic
# scale_ub: Optional upper bound for scaling factor in dynamic per token case
# per token case num_token_padding: If specified, pad the first dimension
# num_token_padding: If specified, pad the first dimension of the output to at least this value.
# of the output to at least this value. use_per_token_if_dynamic: Whether to do per_tensor or per_token
# use_per_token_if_dynamic: Whether to do per_tensor or per_token in the dynamic quantization case.
# in the dynamic quantization case.
Returns:
# Returns: tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor.
# scaling factor. """
# """ # This code assumes batch_dim and num_tokens are flattened
# # This code assumes batch_dim and num_tokens are flattened assert (input.ndim == 2)
# assert (input.ndim == 2) shape: Union[tuple[int, int], torch.Size] = input.shape
# shape: Union[tuple[int, int], torch.Size] = input.shape # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
# # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz out_dtype: torch.dtype = current_platform.fp8_dtype()
# out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding:
# if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1])
# shape = (max(num_token_padding, input.shape[0]), shape[1]) if output is None:
# if output is None: output = torch.empty(shape, device=input.device, dtype=out_dtype)
# output = torch.empty(shape, device=input.device, dtype=out_dtype) else:
# else: assert num_token_padding is None, \
# assert num_token_padding is None, \ "padding not supported if output passed in"
# "padding not supported if output passed in" assert output.dtype == out_dtype
# assert output.dtype == out_dtype
if scale is None:
# if scale is None: if use_per_token_if_dynamic:
# if use_per_token_if_dynamic: scale = torch.empty((shape[0], 1),
# scale = torch.empty((shape[0], 1), device=input.device,
# device=input.device, dtype=torch.float32)
# dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# torch.ops._C.dynamic_per_token_scaled_fp8_quant( output, input, scale, scale_ub)
# output, input, scale, scale_ub) else:
# else: scale = torch.empty(1, device=input.device, dtype=torch.float32)
# scale = torch.empty(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else:
# else: assert scale.numel() == 1, f"{scale.shape}"
# assert scale.numel() == 1, f"{scale.shape}" torch.ops._C.static_scaled_fp8_quant(output, input, scale)
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
# return output, scale
# gptq allspark # gptq allspark
......
...@@ -26,14 +26,14 @@ FP4_DTYPE = torch.uint8 ...@@ -26,14 +26,14 @@ FP4_DTYPE = torch.uint8
SILU_MUL_OP = torch.ops._C.silu_and_mul.default SILU_MUL_OP = torch.ops._C.silu_and_mul.default
# FUSED_OPS: dict[QuantKey, OpOverload] = { FUSED_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501 kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
# } }
# silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr( silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr(
# torch.ops._C, "silu_and_mul_nvfp4_quant")) torch.ops._C, "silu_and_mul_nvfp4_quant"))
# if silu_and_mul_nvfp4_quant_supported: if silu_and_mul_nvfp4_quant_supported:
# FUSED_OPS[ FUSED_OPS[
# kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
class ActivationQuantPattern(ABC): class ActivationQuantPattern(ABC):
......
...@@ -68,15 +68,15 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -68,15 +68,15 @@ class FixFunctionalizationPass(VllmInductorPass):
elif at_target == torch.ops._C.fused_add_rms_norm.default: elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: 'input', 2: 'residual'} mutated_args = {1: 'input', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args) self.defunctionalize(graph, node, mutated_args)
# elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
# mutated_args = {1: 'result', 2: 'residual'} mutated_args = {1: 'result', 2: 'residual'}
# self.defunctionalize(graph, node, mutated_args) self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args) self.defunctionalize(graph, node, mutated_args)
elif at_target in [ elif at_target in [
torch.ops._C.rms_norm.default, torch.ops._C.rms_norm.default,
# torch.ops._C.rms_norm_static_fp8_quant.default, torch.ops._C.rms_norm_static_fp8_quant.default,
]: ]:
mutated_args = {1: 'result'} mutated_args = {1: 'result'}
self.defunctionalize(graph, node, mutated_args) self.defunctionalize(graph, node, mutated_args)
...@@ -89,12 +89,12 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -89,12 +89,12 @@ class FixFunctionalizationPass(VllmInductorPass):
node, node,
mutated_args, mutated_args,
args=('result', 'input')) args=('result', 'input'))
# elif at_target == torch.ops._C.silu_and_mul_quant.default: elif at_target == torch.ops._C.silu_and_mul_quant.default:
# mutated_args = {1: 'result'} mutated_args = {1: 'result'}
# self.defunctionalize(graph, self.defunctionalize(graph,
# node, node,
# mutated_args, mutated_args,
# args=('result', 'input', 'scale')) args=('result', 'input', 'scale'))
# elif hasattr( # elif hasattr(
# torch.ops._C, "silu_and_mul_nvfp4_quant" # torch.ops._C, "silu_and_mul_nvfp4_quant"
# ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default: # ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
......
...@@ -40,12 +40,12 @@ RMS_OP = torch.ops._C.rms_norm.default ...@@ -40,12 +40,12 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = { QUANT_OPS: dict[QuantKey, OpOverload] = {
# kFp8StaticTensorSym: kFp8StaticTensorSym:
# torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
# kFp8DynamicTensorSym: kFp8DynamicTensorSym:
# torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
# kFp8DynamicTokenSym: kFp8DynamicTokenSym:
# torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
} }
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
...@@ -66,14 +66,14 @@ class FusedRMSQuantKey(NamedTuple): ...@@ -66,14 +66,14 @@ class FusedRMSQuantKey(NamedTuple):
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
# FusedRMSQuantKey(kFp8StaticTensorSym, False): FusedRMSQuantKey(kFp8StaticTensorSym, False):
# torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
# FusedRMSQuantKey(kFp8StaticTensorSym, True): FusedRMSQuantKey(kFp8StaticTensorSym, True):
# torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
# FusedRMSQuantKey(kFp8DynamicTokenSym, False): FusedRMSQuantKey(kFp8DynamicTokenSym, False):
# torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
# FusedRMSQuantKey(kFp8DynamicTokenSym, True): FusedRMSQuantKey(kFp8DynamicTokenSym, True):
# torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
} }
...@@ -351,22 +351,22 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): ...@@ -351,22 +351,22 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
self.patterns: PatternMatcherPass = PatternMatcherPass( self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="rmsnorm_quant_fusion_pass") pass_name="rmsnorm_quant_fusion_pass")
# for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant # Fuse rms_norm + static fp8 quant
# RMSNormStaticQuantPattern(epsilon, RMSNormStaticQuantPattern(epsilon,
# FP8_DTYPE).register(self.patterns) FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant # Fuse fused_add_rms_norm + static fp8 quant
# FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
# self.patterns) self.patterns)
# # Fuse rms_norm + dynamic per-token fp8 quant # # Fuse rms_norm + dynamic per-token fp8 quant
# RMSNormDynamicQuantPattern(epsilon, RMSNormDynamicQuantPattern(epsilon,
# FP8_DTYPE).register(self.patterns) FP8_DTYPE).register(self.patterns)
# # Fuse fused_add_rms_norm + dynamic per-token fp8 quant # # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
# FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
# self.patterns) self.patterns)
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
......
...@@ -446,16 +446,16 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -446,16 +446,16 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
for epsilon in [1e-5, 1e-6]: for epsilon in [1e-5, 1e-6]:
# RMSNorm + Static FP8 quantization patterns # RMSNorm + Static FP8 quantization patterns
# fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
# FirstAllReduceRMSNormStaticFP8Pattern( FirstAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device, epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns) fp8_quant_op).register(self.patterns)
# MiddleAllReduceRMSNormStaticFP8Pattern( MiddleAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device, epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns) fp8_quant_op).register(self.patterns)
# LastAllReduceRMSNormStaticFP8Pattern( LastAllReduceRMSNormStaticFP8Pattern(
# epsilon, self.model_dtype, self.device, epsilon, self.model_dtype, self.device,
# fp8_quant_op).register(self.patterns) fp8_quant_op).register(self.patterns)
# Normal RMSNorm patterns # Normal RMSNorm patterns
FirstAllReduceRMSNormPattern(epsilon, self.model_dtype, FirstAllReduceRMSNormPattern(epsilon, self.model_dtype,
......
...@@ -214,7 +214,6 @@ if TYPE_CHECKING: ...@@ -214,7 +214,6 @@ if TYPE_CHECKING:
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_TREE_DECODING: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16 VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
...@@ -1545,12 +1544,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1545,12 +1544,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")), ("true", "1")),
# If set, vLLM will use tree-style speculative decoding.
"VLLM_TREE_DECODING":
lambda:
(os.environ.get("VLLM_TREE_DECODING", "0").strip().lower() in
("1", "true")),
# If set, vLLM will disable the draft model in cudagraph mode. # If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_SPEC_DECODE_EAGER": "VLLM_SPEC_DECODE_EAGER":
lambda: bool(int(os.getenv("VLLM_SPEC_DECODE_EAGER", "0"))), lambda: bool(int(os.getenv("VLLM_SPEC_DECODE_EAGER", "0"))),
......
...@@ -141,8 +141,6 @@ class LLMEngine: ...@@ -141,8 +141,6 @@ class LLMEngine:
# Don't keep the dummy data in memory # Don't keep the dummy data in memory
self.reset_mm_cache() self.reset_mm_cache()
# self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
@classmethod @classmethod
def from_vllm_config( def from_vllm_config(
cls, cls,
......
...@@ -52,8 +52,6 @@ class WorkerBase: ...@@ -52,8 +52,6 @@ class WorkerBase:
different hardware. Also abstracts control plane communication, e.g., to different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers. communicate request metadata to other workers.
""" """
# TODO
tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def __init__( def __init__(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment