Commit 34c31a8d authored by laibao

Update README.md to include detailed information about GLM-4V-9B, its capabilities, model structure, algorithms, environment setup, inference instructions, and application scenarios.
parent e6dcd9bd
from itertools import accumulate
from typing import List, Optional
import nvtx
import torch
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
from vllm.utils import FlexibleArgumentParser, seed_everything
def benchmark_rope_kernels_multi_lora(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
    # simulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, {
"type": "linear",
"factor": tuple(scaling_factors)
})
# non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior
non_batched_ropes: List[RotaryEmbedding] = []
for scaling_factor in scaling_factors:
non_batched_ropes.append(
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
{
"type": "linear",
"factor": (scaling_factor, )
}))
positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query)
    # create query offsets for batched RoPE: we concatenate multiple kv caches
    # together and each query needs to find the right kv cache for its type
offset_map = torch.tensor(
list(
accumulate([0] + [
max_position * scaling_factor * 2
for scaling_factor in scaling_factors[:-1]
])))
query_types = torch.randint(0,
len(scaling_factors), (batch_size, seq_len),
device=device)
# map query types to offsets
query_offsets = offset_map[query_types]
# the kernel takes flattened offsets
flatten_offsets = query_offsets.flatten()
    # batch queries of the same type together for non-batched RoPE
queries = [query[query_types == i] for i in range(len(scaling_factors))]
keys = [key[query_types == i] for i in range(len(scaling_factors))]
packed_qkr = zip(queries, keys, non_batched_ropes)
# synchronize before start timing
torch.cuda.synchronize()
with nvtx.annotate("non-batched", color="yellow"):
for q, k, r in packed_qkr:
r.forward(positions, q, k)
torch.cuda.synchronize()
with nvtx.annotate("batched", color="green"):
batched_rope.forward(positions, query, key, flatten_offsets)
torch.cuda.synchronize()
if __name__ == '__main__':
parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.")
parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--seq-len", type=int, default=512)
parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size",
type=int,
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype",
type=str,
choices=["bfloat16", "float"],
default="float")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device",
type=str,
choices=["cuda:0", "cuda:1"],
default="cuda:0")
args = parser.parse_args()
print(args)
benchmark_rope_kernels_multi_lora(
is_neox_style=args.is_neox_style,
batch_size=args.batch_size,
seq_len=args.seq_len,
num_heads=args.num_heads,
head_size=args.head_size,
rotary_dim=args.rotary_dim,
dtype=getattr(torch, args.dtype),
seed=args.seed,
device=args.device,
)
WEIGHT_SHAPES = {
"ideal": [[4 * 256 * 32, 256 * 32]],
"mistralai/Mistral-7B-v0.1/TP1": [
[4096, 6144],
[4096, 4096],
[4096, 28672],
[14336, 4096],
],
"mistralai/Mistral-7B-v0.1/TP2": [
[4096, 3072],
[2048, 4096],
[4096, 14336],
[7168, 4096],
],
"mistralai/Mistral-7B-v0.1/TP4": [
[4096, 1536],
[1024, 4096],
[4096, 7168],
[3584, 4096],
],
"meta-llama/Llama-2-7b-hf/TP1": [
[4096, 12288],
[4096, 4096],
[4096, 22016],
[11008, 4096],
],
"meta-llama/Llama-2-7b-hf/TP2": [
[4096, 6144],
[2048, 4096],
[4096, 11008],
[5504, 4096],
],
"meta-llama/Llama-2-7b-hf/TP4": [
[4096, 3072],
[1024, 4096],
[4096, 5504],
[2752, 4096],
],
"meta-llama/Llama-2-13b-hf/TP1": [
[5120, 15360],
[5120, 5120],
[5120, 27648],
[13824, 5120],
],
"meta-llama/Llama-2-13b-hf/TP2": [
[5120, 7680],
[2560, 5120],
[5120, 13824],
[6912, 5120],
],
"meta-llama/Llama-2-13b-hf/TP4": [
[5120, 3840],
[1280, 5120],
[5120, 6912],
[3456, 5120],
],
"meta-llama/Llama-2-70b-hf/TP1": [
[8192, 10240],
[8192, 8192],
[8192, 57344],
[28672, 8192],
],
"meta-llama/Llama-2-70b-hf/TP2": [
[8192, 5120],
[4096, 8192],
[8192, 28672],
[14336, 8192],
],
"meta-llama/Llama-2-70b-hf/TP4": [
[8192, 2560],
[2048, 8192],
[8192, 14336],
[7168, 8192],
],
}
import math
import pickle
import re
from collections import defaultdict
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement
from vllm.utils import FlexibleArgumentParser
if __name__ == "__main__":
parser = FlexibleArgumentParser(
        description='Plot the machete benchmark results stored in a pickled '
        'list of torch.utils.benchmark Measurements.')
parser.add_argument('filename', type=str)
args = parser.parse_args()
with open(args.filename, 'rb') as f:
data: List[TMeasurement] = pickle.load(f)
results = defaultdict(lambda: list())
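    # Each measurement's sub_label encodes the GEMM size as "MKN=(MxKxN)";
    # group results by the KxN shape and treat M as the batch size.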
for v in data:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None:
KN = result.group(1)
else:
raise Exception("MKN not found")
result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
if result is not None:
M = result.group(1)
else:
raise Exception("MKN not found")
kernel = v.task_spec.description
results[KN].append({
"kernel": kernel,
"batch_size": M,
"median": v.median
})
rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
axs = axs.flatten()
for axs_idx, (shape, data) in enumerate(results.items()):
plt.sca(axs[axs_idx])
df = pd.DataFrame(data)
sns.lineplot(data=df,
x="batch_size",
y="median",
hue="kernel",
style="kernel",
markers=True,
dashes=False,
palette="Dark2")
plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)")
plt.tight_layout()
plt.savefig("graph_machete_bench.pdf")
pandas
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-3-8b": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}
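# Illustrative sketch (not part of the original file): given a ([K, N],
# TP_SPLIT_DIM) entry from WEIGHT_SHAPES above, this hypothetical helper
# computes the per-rank GEMM shape for a given tensor-parallel size,
# matching the worked examples in the header comment.
def shard_weight_shape(shape, tp_split_dim, tp_size):
    k, n = shape
    if tp_split_dim == 0:
        # K is split across ranks, e.g. ([14336, 4096], 0) at TP2 -> [7168, 4096]
        return [k // tp_size, n]
    # N is split across ranks, e.g. ([4096, 6144], 1) at TP4 -> [4096, 1536]
    return [k, n // tp_size]


assert shard_weight_shape([14336, 4096], 0, 2) == [7168, 4096]
assert shard_weight_shape([4096, 6144], 1, 4) == [4096, 1536]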
#!/bin/bash
PORT=8000
MODEL=$1
TOKENS=$2
docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
-v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0 \
--model-id $MODEL \
--sharded false \
--max-input-length 1024 \
--max-total-tokens 2048 \
--max-best-of 5 \
--max-concurrent-requests 5000 \
--max-batch-total-tokens $TOKENS
import cProfile
import pstats
from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)
def main(args):
llm = LLM(
model=args.model,
enforce_eager=True,
enable_prefix_caching=True,
tensor_parallel_size=args.tensor_parallel_size,
use_v2_block_manager=args.use_v2_block_manager,
)
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
profiler = cProfile.Profile()
print("------warm up------")
for i in range(3):
output = llm.generate(LONG_PROMPT, sampling_params)
print(output[0].outputs[0].text)
print("------start generating------")
for i in range(3):
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
globals(), locals())
# analyze the runtime of hashing function
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
total_time = 0
total_calls = 0
for func in stats.stats:
if 'hash_of_block' in func[2]:
total_time = stats.stats[func][3]
total_calls = stats.stats[func][0]
percentage = (total_time / stats.total_tt) * 100
print(f"Hashing took {total_time:.2f} seconds,"
f"{percentage:.2f}% of the total runtime.")
if __name__ == "__main__":
parser = FlexibleArgumentParser(
        description='Benchmark the performance of the hashing function in '
        'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
                        help='Use BlockSpaceManagerV2')
args = parser.parse_args()
main(args)
FROM fairest creatures we desire increase,
That thereby beauty's rose might never die,
But as the riper should by time decease,
His tender heir might bear his memory:
But thou, contracted to thine own bright eyes,
Feed'st thy light'st flame with self-substantial fuel,
Making a famine where abundance lies,
Thyself thy foe, to thy sweet self too cruel.
Thou that art now the world's fresh ornament
And only herald to the gaudy spring,
Within thine own bud buriest thy content
And, tender churl, makest waste in niggarding.
Pity the world, or else this glutton be,
To eat the world's due, by the grave and thee.
When forty winters shall beseige thy brow,
And dig deep trenches in thy beauty's field,
Thy youth's proud livery, so gazed on now,
Will be a tatter'd weed, of small worth held:
Then being ask'd where all thy beauty lies,
Where all the treasure of thy lusty days,
To say, within thine own deep-sunken eyes,
Were an all-eating shame and thriftless praise.
How much more praise deserved thy beauty's use,
If thou couldst answer 'This fair child of mine
Shall sum my count and make my old excuse,'
Proving his beauty by succession thine!
This were to be new made when thou art old,
And see thy blood warm when thou feel'st it cold.
Look in thy glass, and tell the face thou viewest
Now is the time that face should form another;
Whose fresh repair if now thou not renewest,
Thou dost beguile the world, unbless some mother.
For where is she so fair whose unear'd womb
Disdains the tillage of thy husbandry?
Or who is he so fond will be the tomb
Of his self-love, to stop posterity?
Thou art thy mother's glass, and she in thee
Calls back the lovely April of her prime:
So thou through windows of thine age shall see
Despite of wrinkles this thy golden time.
But if thou live, remember'd not to be,
Die single, and thine image dies with thee.
Unthrifty loveliness, why dost thou spend
Upon thyself thy beauty's legacy?
Nature's bequest gives nothing but doth lend,
And being frank she lends to those are free.
Then, beauteous niggard, why dost thou abuse
The bounteous largess given thee to give?
Profitless usurer, why dost thou use
So great a sum of sums, yet canst not live?
For having traffic with thyself alone,
Thou of thyself thy sweet self dost deceive.
Then how, when nature calls thee to be gone,
What acceptable audit canst thou leave?
Thy unused beauty must be tomb'd with thee,
Which, used, lives th' executor to be.
Those hours, that with gentle work did frame
The lovely gaze where every eye doth dwell,
Will play the tyrants to the very same
And that unfair which fairly doth excel:
For never-resting time leads summer on
To hideous winter and confounds him there;
Sap cheque'd with frost and lusty leaves quite gone,
Beauty o'ersnow'd and bareness every where:
Then, were not summer's distillation left,
A liquid prisoner pent in walls of glass,
Beauty's effect with beauty were bereft,
Nor it nor no remembrance what it was:
But flowers distill'd though they with winter meet,
Leese but their show; their substance still lives sweet.
Then let not winter's ragged hand deface
In thee thy summer, ere thou be distill'd:
Make sweet some vial; treasure thou some place
With beauty's treasure, ere it be self-kill'd.
That use is not forbidden usury,
Which happies those that pay the willing loan;
That's for thyself to breed another thee,
Or ten times happier, be it ten for one;
Ten times thyself were happier than thou art,
If ten of thine ten times refigured thee:
Then what could death do, if thou shouldst depart,
Leaving thee living in posterity?
Be not self-will'd, for thou art much too fair
To be death's conquest and make worms thine heir.
Lo! in the orient when the gracious light
Lifts up his burning head, each under eye
Doth homage to his new-appearing sight,
Serving with looks his sacred majesty;
And having climb'd the steep-up heavenly hill,
Resembling strong youth in his middle age,
yet mortal looks adore his beauty still,
Attending on his golden pilgrimage;
But when from highmost pitch, with weary car,
Like feeble age, he reeleth from the day,
The eyes, 'fore duteous, now converted are
From his low tract and look another way:
So thou, thyself out-going in thy noon,
Unlook'd on diest, unless thou get a son.
Music to hear, why hear'st thou music sadly?
Sweets with sweets war not, joy delights in joy.
Why lovest thou that which thou receivest not gladly,
Or else receivest with pleasure thine annoy?
If the true concord of well-tuned sounds,
By unions married, do offend thine ear,
They do but sweetly chide thee, who confounds
In singleness the parts that thou shouldst bear.
Mark how one string, sweet husband to another,
Strikes each in each by mutual ordering,
Resembling sire and child and happy mother
Who all in one, one pleasing note do sing:
Whose speechless song, being many, seeming one,
Sings this to thee: 'thou single wilt prove none.'
Is it for fear to wet a widow's eye
That thou consumest thyself in single life?
Ah! if thou issueless shalt hap to die.
The world will wail thee, like a makeless wife;
The world will be thy widow and still weep
That thou no form of thee hast left behind,
When every private widow well may keep
By children's eyes her husband's shape in mind.
Look, what an unthrift in the world doth spend
Shifts but his place, for still the world enjoys it;
But beauty's waste hath in the world an end,
And kept unused, the user so destroys it.
No love toward others in that bosom sits
That on himself such murderous shame commits.
For shame! deny that thou bear'st love to any,
Who for thyself art so unprovident.
Grant, if thou wilt, thou art beloved of many,
But that thou none lovest is most evident;
For thou art so possess'd with murderous hate
That 'gainst thyself thou stick'st not to conspire.
Seeking that beauteous roof to ruinate
Which to repair should be thy chief desire.
O, change thy thought, that I may change my mind!
Shall hate be fairer lodged than gentle love?
Be, as thy presence is, gracious and kind,
Or to thyself at least kind-hearted prove:
Make thee another self, for love of me,
That beauty still may live in thine or thee.
As fast as thou shalt wane, so fast thou growest
In one of thine, from that which thou departest;
And that fresh blood which youngly thou bestowest
Thou mayst call thine when thou from youth convertest.
Herein lives wisdom, beauty and increase:
Without this, folly, age and cold decay:
If all were minded so, the times should cease
And threescore year would make the world away.
Let those whom Nature hath not made for store,
Harsh featureless and rude, barrenly perish:
Look, whom she best endow'd she gave the more;
Which bounteous gift thou shouldst in bounty cherish:
She carved thee for her seal, and meant thereby
Thou shouldst print more, not let that copy die.
When I do count the clock that tells the time,
And see the brave day sunk in hideous night;
When I behold the violet past prime,
And sable curls all silver'd o'er with white;
When lofty trees I see barren of leaves
Which erst from heat did canopy the herd,
And summer's green all girded up in sheaves
Borne on the bier with white and bristly beard,
Then of thy beauty do I question make,
That thou among the wastes of time must go,
Since sweets and beauties do themselves forsake
And die as fast as they see others grow;
And nothing 'gainst Time's scythe can make defence
Save breed, to brave him when he takes thee hence.
O, that you were yourself! but, love, you are
No longer yours than you yourself here live:
Against this coming end you should prepare,
And your sweet semblance to some other give.
So should that beauty which you hold in lease
Find no determination: then you were
Yourself again after yourself's decease,
When your sweet issue your sweet form should bear.
Who lets so fair a house fall to decay,
Which husbandry in honour might uphold
Against the stormy gusts of winter's day
And barren rage of death's eternal cold?
O, none but unthrifts! Dear my love, you know
You had a father: let your son say so.
Not from the stars do I my judgment pluck;
And yet methinks I have astronomy,
But not to tell of good or evil luck,
Of plagues, of dearths, or seasons' quality;
Nor can I fortune to brief minutes tell,
Pointing to each his thunder, rain and wind,
Or say with princes if it shall go well,
By oft predict that I in heaven find:
But from thine eyes my knowledge I derive,
And, constant stars, in them I read such art
As truth and beauty shall together thrive,
If from thyself to store thou wouldst convert;
Or else of thee this I prognosticate:
Thy end is truth's and beauty's doom and date.
When I consider every thing that grows
Holds in perfection but a little moment,
That this huge stage presenteth nought but shows
Whereon the stars in secret influence comment;
When I perceive that men as plants increase,
Cheered and cheque'd even by the self-same sky,
Vaunt in their youthful sap, at height decrease,
And wear their brave state out of memory;
Then the conceit of this inconstant stay
Sets you most rich in youth before my sight,
Where wasteful Time debateth with Decay,
To change your day of youth to sullied night;
And all in war with Time for love of you,
As he takes from you, I engraft you new.
But wherefore do not you a mightier way
Make war upon this bloody tyrant, Time?
And fortify yourself in your decay
With means more blessed than my barren rhyme?
Now stand you on the top of happy hours,
And many maiden gardens yet unset
With virtuous wish would bear your living flowers,
Much liker than your painted counterfeit:
So should the lines of life that life repair,
Which this, Time's pencil, or my pupil pen,
Neither in inward worth nor outward fair,
Can make you live yourself in eyes of men.
To give away yourself keeps yourself still,
And you must live, drawn by your own sweet skill.
Who will believe my verse in time to come,
If it were fill'd with your most high deserts?
Though yet, heaven knows, it is but as a tomb
Which hides your life and shows not half your parts.
If I could write the beauty of your eyes
And in fresh numbers number all your graces,
The age to come would say 'This poet lies:
Such heavenly touches ne'er touch'd earthly faces.'
So should my papers yellow'd with their age
Be scorn'd like old men of less truth than tongue,
And your true rights be term'd a poet's rage
And stretched metre of an antique song:
But were some child of yours alive that time,
You should live twice; in it and in my rhyme.
Shall I compare thee to a summer's day?
Thou art more lovely and more temperate:
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date:
Sometime too hot the eye of heaven shines,
And often is his gold complexion dimm'd;
And every fair from fair sometime declines,
By chance or nature's changing course untrimm'd;
But thy eternal summer shall not fade
Nor lose possession of that fair thou owest;
Nor shall Death brag thou wander'st in his shade,
When in eternal lines to time thou growest:
So long as men can breathe or eyes can see,
So long lives this and this gives life to thee.
Devouring Time, blunt thou the lion's paws,
And make the earth devour her own sweet brood;
Pluck the keen teeth from the fierce tiger's jaws,
And burn the long-lived phoenix in her blood;
Make glad and sorry seasons as thou fleets,
And do whate'er thou wilt, swift-footed Time,
To the wide world and all her fading sweets;
But I forbid thee one most heinous crime:
O, carve not with thy hours my love's fair brow,
Nor draw no lines there with thine antique pen;
Him in thy course untainted do allow
For beauty's pattern to succeeding men.
Yet, do thy worst, old Time: despite thy wrong,
My love shall in my verse ever live young.
A woman's face with Nature's own hand painted
Hast thou, the master-mistress of my passion;
A woman's gentle heart, but not acquainted
With shifting change, as is false women's fashion;
An eye more bright than theirs, less false in rolling,
Gilding the object whereupon it gazeth;
A man in hue, all 'hues' in his controlling,
Much steals men's eyes and women's souls amazeth.
And for a woman wert thou first created;
Till Nature, as she wrought thee, fell a-doting,
And by addition me of thee defeated,
By adding one thing to my purpose nothing.
But since she prick'd thee out for women's pleasure,
Mine be thy love and thy love's use their treasure.
So is it not with me as with that Muse
Stirr'd by a painted beauty to his verse,
Who heaven itself for ornament doth use
And every fair with his fair doth rehearse
Making a couplement of proud compare,
With sun and moon, with earth and sea's rich gems,
With April's first-born flowers, and all things rare
That heaven's air in this huge rondure hems.
O' let me, true in love, but truly write,
And then believe me, my love is as fair
As any mother's child, though not so bright
As those gold candles fix'd in heaven's air:
Let them say more than like of hearsay well;
I will not praise that purpose not to sell.
My glass shall not persuade me I am old,
So long as youth and thou are of one date;
But when in thee time's furrows I behold,
Then look I death my days should expiate.
For all that beauty that doth cover thee
Is but the seemly raiment of my heart,
Which in thy breast doth live, as thine in me:
How can I then be elder than thou art?
O, therefore, love, be of thyself so wary
As I, not for myself, but for thee will;
Bearing thy heart, which I will keep so chary
As tender nurse her babe from faring ill.
Presume not on thy heart when mine is slain;
Thou gavest me thine, not to give back again.
As an unperfect actor on the stage
Who with his fear is put besides his part,
Or some fierce thing replete with too much rage,
Whose strength's abundance weakens his own heart.
So I, for fear of trust, forget to say
The perfect ceremony of love's rite,
And in mine own love's strength seem to decay,
O'ercharged with burden of mine own love's might.
O, let my books be then the eloquence
And dumb presagers of my speaking breast,
Who plead for love and look for recompense
More than that tongue that more hath more express'd.
O, learn to read what silent love hath writ:
To hear with eyes belongs to love's fine wit.
Mine eye hath play'd the painter and hath stell'd
Thy beauty's form in table of my heart;
My body is the frame wherein 'tis held,
And perspective it is the painter's art.
For through the painter must you see his skill,
To find where your true image pictured lies;
Which in my bosom's shop is hanging still,
That hath his windows glazed with thine eyes.
Now see what good turns eyes for eyes have done:
Mine eyes have drawn thy shape, and thine for me
Are windows to my breast, where-through the sun
Delights to peep, to gaze therein on thee;
Yet eyes this cunning want to grace their art;
They draw but what they see, know not the heart.
Let those who are in favour with their stars
Of public honour and proud titles boast,
Whilst I, whom fortune of such triumph bars,
Unlook'd for joy in that I honour most.
Great princes' favourites their fair leaves spread
But as the marigold at the sun's eye,
And in themselves their pride lies buried,
For at a frown they in their glory die.
The painful warrior famoused for fight,
After a thousand victories once foil'd,
Is from the book of honour razed quite,
And all the rest forgot for which he toil'd:
Then happy I, that love and am beloved
Where I may not remove nor be removed.
Lord of my love, to whom in vassalage
Thy merit hath my duty strongly knit,
To thee I send this written embassage,
To witness duty, not to show my wit:
Duty so great, which wit so poor as mine
May make seem bare, in wanting words to show it,
But that I hope some good conceit of thine
In thy soul's thought, all naked, will bestow it;
Till whatsoever star that guides my moving
Points on me graciously with fair aspect
And puts apparel on my tatter'd loving,
To show me worthy of thy sweet respect:
Then may I dare to boast how I do love thee;
Till then not show my head where thou mayst prove me.
Weary with toil, I haste me to my bed,
The dear repose for limbs with travel tired;
But then begins a journey in my head,
To work my mind, when body's work's expired:
For then my thoughts, from far where I abide,
Intend a zealous pilgrimage to thee,
And keep my drooping eyelids open wide,
Looking on darkness which the blind do see
Save that my soul's imaginary sight
Presents thy shadow to my sightless view,
Which, like a jewel hung in ghastly night,
Makes black night beauteous and her old face new.
Lo! thus, by day my limbs, by night my mind,
For thee and for myself no quiet find.
How can I then return in happy plight,
That am debarr'd the benefit of rest?
When day's oppression is not eased by night,
But day by night, and night by day, oppress'd?
And each, though enemies to either's reign,
Do in consent shake hands to torture me;
The one by toil, the other to complain
How far I toil, still farther off from thee.
I tell the day, to please them thou art bright
And dost him grace when clouds do blot the heaven:
So flatter I the swart-complexion'd night,
When sparkling stars twire not thou gild'st the even.
But day doth daily draw my sorrows longer
And night doth nightly make grief's strength seem stronger.
When, in disgrace with fortune and men's eyes,
I all alone beweep my outcast state
And trouble deal heaven with my bootless cries
And look upon myself and curse my fate,
Wishing me like to one more rich in hope,
Featured like him, like him with friends possess'd,
Desiring this man's art and that man's scope,
With what I most enjoy contented least;
Yet in these thoughts myself almost despising,
Haply I think on thee, and then my state,
Like to the lark at break of day arising
From sullen earth, sings hymns at heaven's gate;
For thy sweet love remember'd such wealth brings
That then I scorn to change my state with kings.
When to the sessions of sweet silent thought
I summon up remembrance of things past,
I sigh the lack of many a thing I sought,
And with old woes new wail my dear time's waste:
Then can I drown an eye, unused to flow,
For precious friends hid in death's dateless night,
And weep afresh love's long since cancell'd woe,
And moan the expense of many a vanish'd sight:
Then can I grieve at grievances foregone,
And heavily from woe to woe tell o'er
The sad account of fore-bemoaned moan,
Which I new pay as if not paid before.
But if the while I think on thee, dear friend,
All losses are restored and sorrows end.
Thy bosom is endeared with all hearts,
Which I by lacking have supposed dead,
And there reigns love and all love's loving parts,
And all those friends which I thought buried.
How many a holy and obsequious tear
Hath dear religious love stol'n from mine eye
As interest of the dead, which now appear
But things removed that hidden in thee lie!
Thou art the grave where buried love doth live,
Hung with the trophies of my lovers gone,
Who all their parts of me to thee did give;
That due of many now is thine alone:
Their images I loved I view in thee,
And thou, all they, hast all the all of me.
If thou survive my well-contented day,
When that churl Death my bones with dust shall cover,
And shalt by fortune once more re-survey
These poor rude lines of thy deceased lover,
Compare them with the bettering of the time,
And though they be outstripp'd by every pen,
Reserve them for my love, not for their rhyme,
Exceeded by the height of happier men.
O, then vouchsafe me but this loving thought:
'Had my friend's Muse grown with this growing age,
A dearer birth than this his love had brought,
To march in ranks of better equipage:
But since he died and poets better prove,
Theirs for their style I'll read, his for his love.'
Full many a glorious morning have I seen
Flatter the mountain-tops with sovereign eye,
Kissing with golden face the meadows green,
Gilding pale streams with heavenly alchemy;
Anon permit the basest clouds to ride
With ugly rack on his celestial face,
And from the forlorn world his visage hide,
Stealing unseen to west with this disgrace:
Even so my sun one early morn did shine
With all triumphant splendor on my brow;
But out, alack! he was but one hour mine;
The region cloud hath mask'd him from me now.
Yet him for this my love no whit disdaineth;
Suns of the world may stain when heaven's sun staineth.
Why didst thou promise such a beauteous day,
And make me travel forth without my cloak,
To let base clouds o'ertake me in my way,
Hiding thy bravery in their rotten smoke?
'Tis not enough that through the cloud thou break,
To dry the rain on my storm-beaten face,
For no man well of such a salve can speak
That heals the wound and cures not the disgrace:
Nor can thy shame give physic to my grief;
Though thou repent, yet I have still the loss:
The offender's sorrow lends but weak relief
To him that bears the strong offence's cross.
Ah! but those tears are pearl which thy love sheds,
And they are rich and ransom all ill deeds.
No more be grieved at that which thou hast done:
Roses have thorns, and silver fountains mud;
Clouds and eclipses stain both moon and sun,
And loathsome canker lives in sweetest bud.
All men make faults, and even I in this,
Authorizing thy trespass with compare,
Myself corrupting, salving thy amiss,
Excusing thy sins more than thy sins are;
For to thy sensual fault I bring in sense--
Thy adverse party is thy advocate--
And 'gainst myself a lawful plea commence:
Such civil war is in my love and hate
That I an accessary needs must be
To that sweet thief which sourly robs from me.
Let me confess that we two must be twain,
Although our undivided loves are one:
So shall those blots that do with me remain
Without thy help by me be borne alone.
In our two loves there is but one respect,
Though in our lives a separable spite,
Which though it alter not love's sole effect,
Yet doth it steal sweet hours from love's delight.
I may not evermore acknowledge thee,
Lest my bewailed guilt should do thee shame,
Nor thou with public kindness honour me,
Unless thou take that honour from thy name:
But do not so; I love thee in such sort
As, thou being mine, mine is thy good report.
As a decrepit father takes delight
To see his active child do deeds of youth,
So I, made lame by fortune's dearest spite,
Take all my comfort of thy worth and truth.
For whether beauty, birth, or wealth, or wit,
Or any of these all, or all, or more,
Entitled in thy parts do crowned sit,
I make my love engrafted to this store:
So then I am not lame, poor, nor despised,
Whilst that this shadow doth such substance give
That I in thy abundance am sufficed
And by a part of all thy glory live.
Look, what is best, that best I wish in thee:
This wish I have; then ten times happy me!
# FP8 KV Cache
This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (hcu) platforms.
## Prerequisites
- Python 3.x
- PyTorch
- NumPy
- Hugging Face Transformers
- Hugging Face Hub
- AMMO
Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps:
1. Install all necessary prerequisites and dependencies.
2. Convert HF model into a quantized HF model.
3. Extract KV Cache Scaling Factors from quantized HF model.
4. Load KV Cache Scaling Factors into vLLM.
### 2. Convert HF model into a quantized HF model.
Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md).
`quantize.py` (examples/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format).
The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/fp8/quantizer/README.md`.
### 3. Extract KV Cache Scaling Factors from quantized HF model.
`extract_scales.py` (examples/fp8/extract_scales.py) can be used to extract the KV cache scaling factors from your quantized HF model; however, at the moment this tool exclusively supports Llama 2 models. It is also important to note the following:
1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention in which the TP rank appears immediately after a specific keyword (e.g., "rank") in the filename (see the sketch after this list).
2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM.
3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks.
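As a rough illustration of the naming convention in point 1, the sketch below shows how a TP rank could be recovered from a tensor filename. The `parse_tp_rank` helper and the example filenames are hypothetical and only mirror the logic used by `extract_scales.py`.
```python
# Hypothetical sketch: recover the TP rank that follows the rank keyword
# (e.g. "rank") in a tensor filename.
def parse_tp_rank(filename: str, rank_keyword: str = "rank") -> int:
    idx = filename.find(rank_keyword)
    if idx == -1:
        raise ValueError(f"'{rank_keyword}' not found in {filename!r}")
    start = idx + len(rank_keyword)
    stop = start
    while stop < len(filename) and filename[stop].isdecimal():
        stop += 1
    if stop == start:
        raise ValueError(f"No rank number after '{rank_keyword}' in {filename!r}")
    return int(filename[start:stop])

# Example filenames are illustrative only.
assert parse_tp_rank("llama2-70b-fp8-rank0.safetensors") == 0
assert parse_tp_rank("llama2-70b-fp8-rank3.safetensors") == 3
```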
```python
# prerequisites:
# - Quantized HF LLaMa 2 model
python3 examples/fp8/extract_scales.py --help
Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE]
KV Scale Extraction Example
Required arguments:
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (hcu).
Optional arguments:
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
--revision: Specify the model's revision number. (Default: None)
--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None)
--output_name: Specify the output filename. (Default: kv_cache_scales.json)
--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None)
```
```python
Example:
python3 examples/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TENSOR_PARALLEL_SIZE> --output_dir <PATH_TO_OUTPUT_DIR>
```
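The resulting JSON (by default `kv_cache_scales.json`) follows the layout produced by `QuantParamSchema` in `extract_scales.py` below: a top-level `model_type` plus a `kv_cache` entry whose `scaling_factor` maps each TP rank to its per-layer scales. A minimal sketch for inspecting the file (the path and printed fields are illustrative):
```python
import json

# Sketch: inspect an extracted kv_cache_scales.json (the path is a placeholder).
with open("kv_cache_scales.json") as f:
    scales = json.load(f)

print(scales["model_type"])           # e.g. "llama"
print(scales["kv_cache"]["dtype"])    # e.g. "float8_e4m3fn"
# scaling_factor maps TP rank -> layer index -> per-tensor scale
for rank, layer_scales in scales["kv_cache"]["scaling_factor"].items():
    print(f"rank {rank}: {len(layer_scales)} layers")
```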
### 4. Load KV Cache Scaling Factors into vLLM.
`benchmarks/benchmark_throughput.py` evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The KV cache scaling factors extracted in the previous step are integrated into the benchmarking process, so the FP8 KV cache can use the extracted per-layer scales.
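Beyond the benchmark script, the same two settings can be passed directly when constructing an engine in Python; a minimal sketch, with both paths as placeholders (mirroring the `kv_cache_dtype` and `quantization_param_path` arguments used by `run_vllm` later in this document):
```python
from vllm import LLM, SamplingParams

# Minimal sketch: enable the FP8 KV cache and load the extracted scales.
# Both paths are placeholders.
llm = LLM(
    model="<path-to-llama2>",
    kv_cache_dtype="fp8",
    quantization_param_path="<path/to/kv_cache_scales.json>",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```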
```python
# prerequisites:
# - LLaMa 2 kv_cache_scales.json file
python3 benchmarks/benchmark_throughput.py --help
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
[--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
[--quantization-param-path KV_CACHE_quantization_param_path]
Benchmark Throughput Example
optional arguments:
-h, --help show this help message and exit
--backend {vllm,hf,mii}
--dataset DATASET Path to the dataset.
--input-len INPUT_LEN Input prompt length for each request
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
--model MODEL
--tokenizer TOKENIZER
--quantization {awq,gptq,None}, -q {awq,gptq,None}
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
--n N Number of generated sequences per prompt.
--use-beam-search
--num-prompts NUM_PROMPTS Number of prompts to process.
--seed SEED
--hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend.
--trust-remote-code trust remote code from huggingface
--max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
--dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
--enforce-eager enforce eager execution
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is instead supported for common inference criteria.
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is instead supported for common inference criteria.
```
```
Example:
python3 benchmarks/benchmark_throughput.py --input-len <INPUT_LEN> --output-len <OUTPUT_LEN> -tp <TENSOR_PARALLEL_SIZE> --kv-cache-dtype fp8 --quantization-param-path <path/to/kv_cache_scales.json> --model <path-to-llama2>
```
import argparse
import glob
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from safetensors.torch import safe_open
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
# Adapted from vllm/model_executor/model_loader/weight_utils.py
# The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it)
def _prepare_hf_weights(
quantized_model_dir: str,
load_format: str = "auto",
fall_back_to_pt: bool = True,
) -> Tuple[List[str], bool]:
if not os.path.isdir(quantized_model_dir):
raise FileNotFoundError(
f"The quantized model directory `{quantized_model_dir}` "
"does not exist.")
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == "npz":
allow_patterns = ["*.npz"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(
os.path.join(quantized_model_dir, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if not use_safetensors:
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
]
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{quantized_model_dir}`")
return hf_weights_files, use_safetensors
# Adapted from vllm/model_executor/model_loader/weight_utils.py
def _hf_tensorfile_iterator(filename: str, load_format: str,
use_safetensors: bool):
if load_format == "npz":
assert not use_safetensors
with np.load(filename) as data:
for name in data.files:
param = torch.from_numpy(data[name])
yield name, param
elif use_safetensors:
with safe_open(filename, framework="pt") as f:
for name in f.keys(): # NOQA: SIM118
param = f.get_tensor(name)
yield name, param
else:
state = torch.load(filename, map_location="cpu")
for name, param in state.items():
yield name, param
del state
torch.cuda.empty_cache()
def _kv_scales_extractor(
hf_tensor_files: List[str],
use_safetensors: bool,
rank_keyword: str = "rank",
expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
"""
Given a list of files containing tensor data, attempt to extract KV cache
scales from these files. Intended as a helper function taking in the output
from _prepare_hf_weights.
Args:
rank_keyword Matches the number immediately after this keyword in the
tensor filename to determine the TP rank corresponding
to said tensor file
expected_tp_size If specified, the TP size of the tensor files is checked
against this and an error is raised if they don't match.
Returns a dictionary mapping TP ranks to their relevant KV cache scales.
The per-rank scales are themselves represented as a dictionary of layer
indices to the respective per-layer scale.
"""
for char in rank_keyword:
assert not char.isdecimal(
), f"Rank keyword {rank_keyword} contains a numeric character!"
rank_scales_map: Dict[int, Dict[int, float]] = {}
for tensor_file in hf_tensor_files:
try:
rank_idx = tensor_file.find(rank_keyword)
if rank_idx != -1:
start_idx = rank_idx + len(rank_keyword)
stop_idx = start_idx
while stop_idx < len(
tensor_file) and tensor_file[stop_idx].isdecimal():
stop_idx += 1
if stop_idx == start_idx:
raise RuntimeError("Did not find rank # in filename.")
rank = int(tensor_file[start_idx:stop_idx])
elif len(hf_tensor_files) == 1:
# Since there is only one tensor file, we can assume
# that it's intended for TP rank 0
rank = 0
else:
raise RuntimeError(
f"Filename does not contain '{rank_keyword}'.")
except RuntimeError:
print("Unable to determine TP rank "
f"corresponding to file '{tensor_file}'")
raise
if rank not in rank_scales_map:
layer_scales_map: Dict[int, float] = {}
rank_scales_map[rank] = layer_scales_map
else:
raise RuntimeError(
f"Tensor file '{tensor_file}' shares TP rank {rank} "
"with another tensor file.")
module_delimiter = ":" if args.load_format == "npz" else "."
for name, param in _hf_tensorfile_iterator(tensor_file,
args.load_format,
use_safetensors):
if "kv_cache_scaling_factor" in name:
nums = [
int(s) for s in name.split(module_delimiter)
if s.isdecimal()
]
assert len(
nums) == 1, f"Could not determine layer idx for {name}"
layer_idx = nums[0]
assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
f" factor corresponding to layer {layer_idx}"
try:
layer_scales_map[layer_idx] = param.item()
except RuntimeError:
print(
"This utility supports only per-tensor scalar scales "
f"for now. The tensor\n {name} = {param} \nis an "
"invalid scale factor.")
raise
if all(
len(layer_scales_map) == 0
for layer_scales_map in rank_scales_map.values()):
# Note: this is true even if the rank_scales_map is empty
print("WARNING: No KV cache scale factors found. No output saved.")
return None
empirical_tp_world_size = max(rank_scales_map.keys()) + 1
if expected_tp_size is not None:
assert expected_tp_size == empirical_tp_world_size, \
f"User expected TP world size = {expected_tp_size} " \
"from model but tool is expecting TP world size = " \
f"{empirical_tp_world_size} from model instead."
for i in range(empirical_tp_world_size):
assert i in rank_scales_map, "Expected TP world size = "\
f"{empirical_tp_world_size} but did not find KV " \
f"cache scaling factors for TP rank {i}"
print(f"Found TP world size = {empirical_tp_world_size} "
"when extracting KV cache scales!")
return rank_scales_map
def _metadata_extractor(quantized_model_dir: str,
metadata_extract_fns: \
Dict[str, Callable[[Dict[str, Any]], Any]]) \
-> Dict[str, Any]:
"""
Given a directory containing quantized model files, this function
aims to extract metadata from the JSON files within this directory.
Each JSON file is expected to represent a dictionary in JSON
format (referred to as a "JSON-dictionary"). Metadata extraction is
defined by a dictionary called metadata_extract_fns, where each
metadata field name is mapped to an extraction function.
These extraction functions are designed to take a JSON-dictionary
as their only argument and return the corresponding metadata.
While extraction functions are permitted to raise exceptions, they
should only raise a KeyError or ValueError if the metadata field
cannot be extracted from the current JSON-dictionary, yet there's
a possibility of finding it in another JSON-dictionary.
The function returns a dictionary that maps metadata fields to
their extracted data. The keys of this dictionary correspond exactly
to those in metadata_extract_fns. If any fields fail to be extracted,
their corresponding values are set to None, and a warning is printed.
"""
if not os.path.isdir(quantized_model_dir):
raise FileNotFoundError(
f"The quantized model directory `{quantized_model_dir}` "
"does not exist.")
metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
result: Dict[str, Any] = {}
for file in metadata_files:
with open(file) as f:
try:
metadata = json.load(f)
except json.JSONDecodeError:
print(f"Could not parse `{file}` as a valid metadata file,"
" skipping it.")
continue
if not isinstance(metadata, dict):
print(f"The file `{file}` does not correspond to a "
"JSON-serialized dictionary, skipping it.")
continue
for metadata_name, extract_fn in metadata_extract_fns.items():
try:
metadata_info = extract_fn(metadata)
if metadata_name not in result:
result[metadata_name] = metadata_info
elif metadata_info != result[metadata_name]:
raise RuntimeError(
"Metadata mismatch! Originally found "
f"{metadata_name} = {result[metadata_name]} but "
f"now found {metadata_name} = {metadata_info} in "
f"`{file}`")
except KeyError:
# It is possible that a given file does not contain some
# of our selected metadata as it could be located in some
# other metadata file.
# 'EFINAE': extract_fn failure is not an error.
pass
except ValueError:
# See above.
pass
# Warn if we cannot find any of the requested metadata
for metadata_name in metadata_extract_fns:
if metadata_name not in result:
print("WARNING: Unable to find requested metadata field "
f"`{metadata_name}`, setting it to None.")
result[metadata_name] = None
return result
def main(args):
metadata_extract_fns = {
"model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
"tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
"model_dtype": lambda json_dict: json_dict["dtype"]
}
recovered_metadata = _metadata_extractor(args.quantized_model,
metadata_extract_fns)
if args.tp_size is not None:
metadata_tp_size = recovered_metadata["tp_size"]
if metadata_tp_size is not None:
assert args.tp_size == metadata_tp_size, \
f"User expected TP world size = {args.tp_size} " \
f"but found TP world size = {metadata_tp_size} from metadata!"
expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
rank_keyword = "rank"
hf_tensor_files, use_safetensors = _prepare_hf_weights(
args.quantized_model, args.load_format)
rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
rank_keyword, expected_tp_size)
# Postprocess: formatting to the current schema. Consider pulling it
# out into a dedicated function should it ever become more complicated.
rank_scales_map = {
rank: {k: scale[k]
for k in sorted(scale.keys())}
for rank, scale in rank_scales_map.items()
}
# TODO: Expand this with activation and weights scaling factors when
# they are used in the future
schema = QuantParamSchema(
model_type=recovered_metadata["model_type"],
kv_cache={
"dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
recovered_metadata["model_dtype"]),
"scaling_factor":
rank_scales_map
},
)
if args.output_dir is None:
output_file = os.path.join(args.quantized_model, args.output_name)
else:
if not os.path.isdir(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
output_file = os.path.join(args.output_dir, args.output_name)
with open(output_file, 'w') as f:
f.write(schema.model_dump_json(indent=4))
print(f"Completed! KV cache scaling factors saved to {output_file}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="This simple utility extracts the "
"KV cache scaling factors from a quantized HF model "
"and saves them to a JSON file compatible with later "
"use by vLLM (pass this file to the appropriate "
"runtime typically using the argument "
"--quantization-param-path <filename>). This is only used "
"if the KV cache dtype is FP8 and on ROCm (hcu).")
parser.add_argument(
"--quantized-model",
help="Specify the directory containing a single quantized HF model. "
"It is expected that the quantization format is FP8_E4M3, for use "
"on ROCm (hcu).",
required=True)
parser.add_argument(
"--load_format",
help="Optionally specify the format of the model's tensor files "
"containing the KV cache scaling factors.",
choices=["auto", "safetensors", "npz", "pt"],
default="auto")
parser.add_argument(
"--output-dir",
help="Optionally specify the output directory. By default the "
"KV cache scaling factors will be saved in the model directory, "
"however you can override this behavior here.",
default=None)
parser.add_argument(
"--output-name",
help="Optionally specify the output filename.",
# TODO: Change this once additional scaling factors are enabled
default="kv_cache_scales.json")
parser.add_argument(
"--tp-size",
help="Optionally specify the tensor-parallel (TP) size that the "
"quantized model should correspond to. If specified, during KV "
"cache scaling factor extraction the observed TP size will be "
"checked against this and an error will be raised if there is "
"a mismatch. If not specified, the quantized model's expected "
"TP size is instead inferred from the largest TP rank observed. "
"The expected TP size is cross-checked against the TP ranks "
"observed in the quantized model and an error is raised if any "
"discrepancies are found.",
default=None,
type=int)
args = parser.parse_args()
main(args)
# Medusa Decoding
This document describes how to build and run Medusa models with vLLM.
## Overview
Medusa is a parallel decoding algorithm for large models. In addition to the official Top1-proposer, we also support tree-style parallel decoding; both the target model and the draft model can run multi-GPU inference.
Unlike other models, Medusa decoding requires a base model and several Medusa heads.
The vLLM Medusa model is implemented in [vllm/model_executor/models/medusa.py].
## Support Matrix
* FP16
* BF16
* PAGED_KV_CACHE
* Tensor Parallel
### Convert Medusa model weights
The Medusa model needs to be converted into the Medusa model format used by vLLM:
```bash
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]"
```
Here `model.bin` is the Medusa head checkpoint saved after training. If you want to use the Top1-proposer, `medusa_choices` can be left unset.
### Run tree-style generation server
```bash
VLLM_TREE_DECODING=1 python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 9 \
--use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \
--dtype float16 --trust-remote-code --port 8086\
--num-speculative-heads 4 --num-speculative-tokens 24
```
Notes:
- `num_speculative_tokens = len(medusa_choices) + 1` (see the check below).
- Do not use too many `medusa_choices`, otherwise inference slows down at larger batch sizes.
- `speculative-disable-by-batch-size` must be greater than `max-num-seqs`; otherwise, once the batch size reaches `max-num-seqs`, parallel decoding is not used.
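As a quick sanity check of the first note, the `medusa_choices` list used in the conversion command above has 23 entries, which matches `--num-speculative-tokens 24` in the server command:
```python
# Values copied from the conversion command above (written as Python tuples;
# the CLI string writes single-element paths as "(0)").
medusa_choices = [(0,), (0, 0), (0, 0, 0), (0, 1), (1,), (1, 0), (0, 0, 0, 0),
                  (0, 0, 1), (0, 2), (0, 1, 0), (2,), (0, 0, 2), (0, 3),
                  (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3,),
                  (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]
assert len(medusa_choices) + 1 == 24  # --num-speculative-tokens 24
```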
### Run Top1-proposer server
```bash
python3 -m vllm.entrypoints.openai.api_server \
    --served-model-name qwen_medusa \
    --model /models/Qwen2-72B-Instruct/ -tp 4 \
    --max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
    --speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
    --speculative-draft-tensor-parallel-size 4 \
    --speculative-disable-by-batch-size 9 \
    --use-v2-block-manager \
    --spec-decoding-acceptance-method typical_acceptance_sampler \
    --dtype float16 --trust-remote-code --port 8086 \
    --num-speculative-tokens 4
```
Note:
When using the Top1-proposer, `num-speculative-tokens` equals the number of Medusa heads.
### Send a request
```bash
curl http://localhost:8086/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "qwen_medusa",
"prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n<|im_start|>assistant\n",
"max_tokens": 256,
"temperature": 0.0
}'
```
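The same completion request can also be issued from Python. A minimal sketch using the `openai` client against the OpenAI-compatible endpoint started above (the client package and the dummy API key are assumptions, not part of this repository):
```python
from openai import OpenAI

# The local server ignores the API key, but the client requires one.
client = OpenAI(base_url="http://localhost:8086/v1", api_key="EMPTY")
completion = client.completions.create(
    model="qwen_medusa",
    prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
           "<|im_start|>user\n帮我写一个C++的快速排序算法<|im_end|>\n"
           "<|im_start|>assistant\n",
    max_tokens=256,
    temperature=0.0,
)
print(completion.choices[0].text)
```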
### Run tree-style benchmark
```bash
VLLM_TREE_DECODING=1 python /work/test/medusa_benchmark_throughput.py --model /models/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 4 --speculative-model /work/medusa/vllm-medusa1-qwen2-72b-head-4 --speculative-draft-tensor-parallel-size 4 --speculative-disable-by-batch-size 9 --use-v2-block-manager --spec-decoding-acceptance-method typical_acceptance_sampler --max-model-len 1024 --dataset /work/medusa_benchmark_data.json --num-speculative-heads 4 --num-speculative-tokens 24 --gpu-memory-utilization 0.95
```
### Run Top1-proposer benchmark
```bash
python /work/test/medusa_benchmark_throughput.py --model /models/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 4 --speculative-model /work/medusa/vllm-medusa1-qwen2-72b-head-4 --speculative-draft-tensor-parallel-size 4 --speculative-disable-by-batch-size 9 --use-v2-block-manager --spec-decoding-acceptance-method typical_acceptance_sampler --max-model-len 1024 --dataset /work/medusa_benchmark_data.json --num-speculative-tokens 4 --gpu-memory-utilization 0.95
```
You can vary `max-num-seqs` to benchmark performance at different batch sizes.
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple
import numpy as np
import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.inputs import PromptInputs
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
from vllm.lora.request import LoRARequest
def nullable_str(val: str):
if not val or val == "None":
return None
return val
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
    # Each dataset entry provides the prompt directly.
    dataset = [data["prompt"] for data in dataset]
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompt to get its length.
prompt = dataset[i]
prompt_token_ids = tokenizer(prompt).input_ids
prompt_len = len(prompt_token_ids)
output_len = fixed_output_len
filtered_dataset.append((prompt, prompt_len, output_len))
return filtered_dataset
def run_vllm(
warmup_requests: List[Tuple[str, int, int]],
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
max_num_seqs: int = 8,
speculative_model: Optional[str] = None,
speculative_draft_tensor_parallel_size: int = 1,
speculative_disable_by_batch_size: int = 4,
spec_decoding_acceptance_method: Optional[str] = None,
enable_lora: bool = False,
max_lora_rank: int = 32,
lora_extra_vocab_size: int = 0,
lora_target_modules: Optional[List[str]] = None,
num_speculative_heads: int = 5,
num_speculative_tokens: int = 64,
use_new_beam_search_impl: bool = False,
lora_modules: Optional[str] = None
) -> Tuple[float, int]:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
max_num_seqs=max_num_seqs,
speculative_model=speculative_model,
speculative_draft_tensor_parallel_size=speculative_draft_tensor_parallel_size,
speculative_disable_by_batch_size=speculative_disable_by_batch_size,
spec_decoding_acceptance_method=spec_decoding_acceptance_method,
enable_lora=enable_lora,
max_lora_rank=max_lora_rank,
lora_extra_vocab_size=lora_extra_vocab_size,
lora_target_modules=lora_target_modules,
num_speculative_heads=num_speculative_heads,
num_speculative_tokens=num_speculative_tokens
)
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=False,
max_tokens=output_len,
))
# warmup
warmup_prompts = []
warmup_sampling_params = []
for prompt, _, output_len in warmup_requests:
warmup_prompts.append(prompt)
warmup_sampling_params.append(
SamplingParams(
n=n,
temperature=0.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=False,
max_tokens=output_len,
))
print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
if lora_modules is None:
llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True)
else:
llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True,
lora_request=LoRARequest("medusa-lora", 1, lora_modules))
total_out_tokens = 0
start = time.perf_counter()
if lora_modules is None:
outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
else:
outputs = llm.generate(prompts, sampling_params, use_tqdm=False,
lora_request=LoRARequest("medusa-lora", 1, lora_modules))
for output in outputs:
print("token_ids len:{} text:{}".format(len(output.outputs[0].token_ids), output.outputs[0].text))
total_out_tokens += len(output.outputs[0].token_ids)
end = time.perf_counter()
return end - start, total_out_tokens
async def run_vllm_async(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
disable_frontend_multiprocessing: bool = False,
max_num_seqs: int = 8,
speculative_model: Optional[str] = None,
speculative_draft_tensor_parallel_size: int = 1,
speculative_disable_by_batch_size: int = 4,
spec_decoding_acceptance_method: Optional[str] = None,
enable_lora: bool = False,
max_lora_rank: int = 32,
lora_extra_vocab_size: int = 0,
lora_target_modules: Optional[List[str]] = None,
num_speculative_heads: int = 5,
num_speculative_tokens: int = 64,
use_new_beam_search_impl: bool = False,
lora_modules: Optional[str] = None
) -> Tuple[float, int]:
from vllm import SamplingParams
engine_args = AsyncEngineArgs(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False,
disable_log_requests=True,
max_num_seqs=max_num_seqs,
speculative_model=speculative_model,
speculative_draft_tensor_parallel_size=speculative_draft_tensor_parallel_size,
speculative_disable_by_batch_size=speculative_disable_by_batch_size,
spec_decoding_acceptance_method=spec_decoding_acceptance_method,
enable_lora=enable_lora,
max_lora_rank=max_lora_rank,
lora_extra_vocab_size=lora_extra_vocab_size,
lora_target_modules=lora_target_modules,
num_speculative_heads=num_speculative_heads,
num_speculative_tokens=num_speculative_tokens
)
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=False,
max_tokens=output_len,
))
generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
out_dict = {}
async for i, res in all_gens:
#print("res:", res)
out_dict[res.request_id] = len(res.outputs[0].token_ids)
end = time.perf_counter()
total_out_tokens = 0
for token_num in out_dict.values():
total_out_tokens += token_num
return end - start, total_out_tokens
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
if args.async_engine:
run_args = [
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc,
args.disable_frontend_multiprocessing, args.max_num_seqs,
args.speculative_model, args.speculative_draft_tensor_parallel_size,
args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method,
args.enable_lora, args.max_lora_rank, args.lora_extra_vocab_size,
args.lora_target_modules, args.num_speculative_heads,
args.num_speculative_tokens
]
else:
run_args = [
warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc, args.max_num_seqs,
args.speculative_model, args.speculative_draft_tensor_parallel_size,
args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method,
args.enable_lora, args.max_lora_rank, args.lora_extra_vocab_size,
args.lora_target_modules, args.num_speculative_heads,
args.num_speculative_tokens
]
if args.async_engine:
elapsed_time, total_out_tokens = uvloop.run(run_vllm_async(*run_args))
else:
elapsed_time, total_out_tokens = run_vllm(*run_args, args.use_new_beam_search_impl, args.lora_modules)
total_num_tokens = total_out_tokens + sum(prompt_len
for _, prompt_len, _ in requests)
print(f"Latency: {elapsed_time:.2f} s")
print(f"All Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
print(f"Generate Throughput: {total_out_tokens / elapsed_time:.2f} tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=256,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument('--num-iters-warmup',
type=int,
default=1,
help='Number of iterations to run for warmup.')
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (hcu) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (hcu), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument("--device",
type=str,
default="auto",
choices=DEVICE_OPTIONS,
help='device type for vLLM execution')
parser.add_argument(
"--num-scheduler-steps",
type=int,
default=1,
help="Maximum number of forward steps per scheduler call.")
parser.add_argument("--use-v2-block-manager",
action='store_true',
help="Enable block manager v2.")
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="Enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
'bitsandbytes'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available.\n'
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.\n'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
' section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument(
'--speculative-model',
type=nullable_str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=EngineArgs.speculative_draft_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument(
'--spec-decoding-acceptance-method',
type=str,
default=EngineArgs.spec_decoding_acceptance_method,
choices=['rejection_sampler', 'typical_acceptance_sampler'],
help='Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument('--merge-lora',
action='store_true',
help='If set, merge the LoRA weights into the base layer weights.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument('--lora-target-modules',
nargs='*',
default=None,
help='List of LoRA module names. If not specified, modules will be chosen according to the model architecture.')
parser.add_argument(
'--num-speculative-heads',
type=int,
default=EngineArgs.num_speculative_heads,
help='The number of speculative heads to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--lora-modules',
type=nullable_str,
default=None,
help=
'Path to the LoRA adapter.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
main(args)
import os
import ast
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
from addict import Dict
import yaml
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from safetensors.torch import save_model, safe_open
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.weight'
TRAINED_MEDUSA_HEADS_NAME_TEMPLATE = 'medusa_head.{}.1.weight'
TRAINED_BLOCK_BIAS_NAME_TEMPLATE = 'medusa_head.{}.{}.linear.bias'
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_BLOCK_BIAS_NAME_TEMPLATE = 'blocks.{}.layers.{}.bias'
VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE = 'lm_heads.{}.weight'
def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
class MedusaConfig(PretrainedConfig):
model_type = "medusa"
def __init__(self,
hidden_size: int = 4096,
vocab_size: int = 32001,
num_heads: int = 5,
num_hidden_layers: int = 1,
max_paths: int = 64,
topk: int = 10,
truncated_vocab_size: Optional[int] = None,
**kwargs):
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.num_heads = num_heads
self.num_hidden_layers = num_hidden_layers
self.max_paths = max_paths
self.topk = topk
self.max_seq_len = int(2**20)
self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\
else truncated_vocab_size
if "architectures" not in kwargs:
kwargs["architectures"] = ["MedusaModel"]
super().__init__(**kwargs)
@property
def num_attention_heads(self):
return 0
@property
def num_lookahead_tokens(self):
return self.num_heads
@num_lookahead_tokens.setter
def num_lookahead_tokens(self, num_lookahead_tokens: int):
self.num_heads = num_lookahead_tokens
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
prefix: full name of the layer in the state dict
""" # noqa: E501
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.embedding_dim = embedding_dim
linear_method = None
if quant_config is not None:
linear_method = quant_config.get_quant_method(self, prefix=prefix)
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method: QuantizeMethodBase = linear_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.linear_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_padded],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
def forward(self, input_):
masked_input = input_
# Get the embeddings.
output = F.embedding(masked_input.long(), self.weight)
return output
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
Output logits weight matrices used in the Sampler. The weight and bias
tensors are padded to make sure they are divisible by the number of
model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config,
prefix)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_padded,  # padded vocab size (no TP sharding in this helper)
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")
class ResidualBlock(nn.Module):
def __init__(self, hidden_size: int, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size)
for _ in range(num_layers)
])
self.act = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = x + self.act(layer(x))
return x
class Medusa(nn.Module):
def __init__(self, config: MedusaConfig, **_) -> None:
super().__init__()
self.config = config
self.blocks = nn.ModuleList([
ResidualBlock(hidden_size=self.config.hidden_size,
num_layers=self.config.num_hidden_layers)
for _ in range(self.config.num_heads)
])
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_heads = nn.ModuleList([
ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) for _ in range(self.config.num_heads)
])
logit_scale = getattr(config, "logit_scale", 1.0)
self.token_map = None
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
return [block(hidden_states) for block in self.blocks]
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
weights_map = {}
for name, loaded_weight in weights:
name = name.replace("medusa_heads.", "")
if name == "token_map":
if self.truncated_vocab_size < self.orig_vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name in params_dict:
weights_map[name] = loaded_weight
for name, loaded_weight in weights_map.items():
if "lm_head" in name and self.token_map is not None and\
loaded_weight.shape[0] > self.token_map.shape[0]:
loaded_weight = loaded_weight[self.token_map]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.token_map is not None:
self.token_map.to(device=self.lm_heads[0].weight.device)
assert (self.truncated_vocab_size
== self.orig_vocab_size) or (self.token_map is not None)
class CustomMedusaConfig(PretrainedConfig):
model_type = "medusa"
def __init__(self,
name_or_path: str = "S-3000/vllm-medusa-qwen1.5-7b-chat",
architectures: list[str] = ["MedusaModel"],
hidden_size: int = 4096,
model_type: str = "medusa",
num_heads: int = 5,
num_hidden_layers: int = 1,
transformers_version: str = "4.41.2",
truncated_vocab_size: Optional[int] = None,
vocab_size: int = 151936,
medusa_choices:List[List[int]] = None,
**kwargs):
super().__init__(**kwargs)
self._name_or_path = name_or_path
self.architectures = architectures
self.hidden_size = hidden_size
self.model_type = model_type
self.num_heads = num_heads
self.num_hidden_layers = num_hidden_layers
self.transformers_version = transformers_version
self.truncated_vocab_size = truncated_vocab_size
self.vocab_size = vocab_size
self.medusa_choices = medusa_choices
def main(args):
medusa_head_num = args.medusa_num_heads
medusa_num_layers = args.medusa_num_layers
config = MedusaConfig(hidden_size=args.hidden_size, vocab_size=args.vocab_size, num_heads=medusa_head_num)
medusa_model = Medusa(config)
params_dict = dict(medusa_model.named_parameters())
trained_medusa_model = torch.load(args.medusa_model_path)
for i in range(medusa_head_num):
vllm_medusa_head_weight_name = VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE.format(i)
trained_medusa_head_weight_name = TRAINED_MEDUSA_HEADS_NAME_TEMPLATE.format(i)
vllm_medusa_head_param = params_dict[vllm_medusa_head_weight_name]
trained_medusa_head_param = trained_medusa_model[trained_medusa_head_weight_name]
weight_loader = getattr(vllm_medusa_head_param, "weight_loader",
default_weight_loader)
weight_loader(vllm_medusa_head_param, trained_medusa_head_param)
for i in range(medusa_head_num):
for j in range(medusa_num_layers):
# load linear weight
vllm_medusa_block_weight_name = VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j)
trained_medusa_block_weight_name = TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j)
vllm_medusa_block_param = params_dict[vllm_medusa_block_weight_name]
trained_medusa_block_param = trained_medusa_model[trained_medusa_block_weight_name]
weight_loader = getattr(vllm_medusa_block_param, "weight_loader",
default_weight_loader)
weight_loader(vllm_medusa_block_param, trained_medusa_block_param)
# load linear bias
vllm_medusa_block_bias_name = VLLM_BLOCK_BIAS_NAME_TEMPLATE.format(i, j)
trained_medusa_block_bias_name = TRAINED_BLOCK_BIAS_NAME_TEMPLATE.format(i, j)
vllm_medusa_block_bias_param = params_dict[vllm_medusa_block_bias_name]
trained_medusa_block_bias_param = trained_medusa_model[trained_medusa_block_bias_name]
weight_loader = getattr(vllm_medusa_block_bias_param, "weight_loader",
default_weight_loader)
weight_loader(vllm_medusa_block_bias_param, trained_medusa_block_bias_param)
if not Path(args.output_dir).is_dir():
os.makedirs(args.output_dir, exist_ok=True)
save_model(medusa_model, os.path.join(args.output_dir, "model.safetensors"))
medusa_choices = ast.literal_eval(args.medusa_choices) if args.medusa_choices is not None else None
to_save_config = CustomMedusaConfig(name_or_path=os.path.join(args.output_dir, "config.json"),
hidden_size=args.hidden_size,
num_heads=medusa_head_num,
num_hidden_layers=medusa_num_layers,
vocab_size=args.vocab_size,
medusa_choices=medusa_choices)
to_save_config.save_pretrained(args.output_dir)
# validate weight
# with safe_open(os.path.join(args.output_dir, "model.safetensors"), framework="pt") as f:
# param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0))
# trained_param = trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(3, 0)]
# mse_value = torch.nn.functional.mse_loss(param.cpu(), trained_param.cpu())
# print("weight mes:", mse_value)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Medusa Model Evaluator")
parser.add_argument("--medusa_model_path", type=str, required=True,
help="Path to the medusa model file.")
parser.add_argument("--vocab_size", type=int, required=True,
help="Vocab size")
parser.add_argument("--medusa_num_heads", type=int, required=True,
help="Number of Medusa heads")
parser.add_argument("--medusa_num_layers", type=int, required=True,
help="Number of Medusa layers")
parser.add_argument("--hidden_size", type=int, required=True,
help="Hidden size")
parser.add_argument("--output_dir", type=str, required=True,
help="Output dir")
parser.add_argument(
'--medusa_choices',
type=str,
default=None,
help="Medusa choice to use, if not none, will use Medusa decoding."
" E.g.: [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]] for 9 medusa tokens."
)
args = parser.parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0
import os
import json
import pytest
import torch
import triton
from triton_decode_attention import decode_attentionv1_fwd, decode_attentionv2_fwd
def cdiv(a, b):
return (a + b - 1) // b
@pytest.mark.parametrize("B", [1])
# @pytest.mark.parametrize("L", [100])
@pytest.mark.parametrize("L", [1,100,400,700,1000,1300,1600,1900,2200,2500,2800,3100,3400,3700,4000,4300,4600,4900,5000,5500,6000,6500,7000,7500,8000,8500,9000,9500,10000,10500,11000,11500,12000,12500,13000,13500,14000,14500,15000,15500,16000,16500,17000,17500,18000,18500,19000,19500,20000,20500,21000,21500,22000,22500,23000,23500,24000,24500,25000,25500,26000,26500,27000,27500,28000,28500,29000,29500,30000,30500,31000,31500,32000,32500])
@pytest.mark.parametrize("H_Q", [4, 8, 16])
@pytest.mark.parametrize("H_KV", [1])
@pytest.mark.parametrize("D_QK", [576])
@pytest.mark.parametrize("D_V", [512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L # This represents the number of tokens already in the sequence
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 4
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)  # ceiling division, e.g. (1027 + 16 - 1) // 16 = 65
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1),  # tensor of shape (B, num_pages_per_batch, 1) with values in [0, CACHE_SIZE // PAGE_SIZE)
device="cuda")
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
b_req_idx = torch.arange(B, device="cuda").to(torch.int32)
# Call the original implementation.
decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
# Page size can be larger than 1.
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
configs = {
"v2_tc": {"stage1": {}, "stage2": {}},
"v1_2stages_tc": {"stage1": {}, "stage2": {}},
}
ms = {
"v1_2stages_tc": 10000.0,
"v2_tc": 10000.0,
}
final_best_config = {
"kernel_kind": "",
"best_config": {},
"best_us": 0.0,
}
v2_tc_stage1_best_config, v2_tc_stage2_best_config = decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)
quantiles = [0.5, 0.2, 0.8]
v2_tc_ms, v2_tc_min_ms, v2_tc_max_ms = triton.testing.do_bench(lambda:
decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
), quantiles=quantiles)
for key, value in v2_tc_stage1_best_config.kwargs.items():
configs["v2_tc"]["stage1"][key] = value
configs["v2_tc"]["stage1"]["num_stages"] = v2_tc_stage1_best_config.num_stages
configs["v2_tc"]["stage1"]["num_warps"] = v2_tc_stage1_best_config.num_warps
for key, value in v2_tc_stage2_best_config.kwargs.items():
configs["v2_tc"]["stage2"][key] = value
configs["v2_tc"]["stage2"]["num_stages"] = v2_tc_stage2_best_config.num_stages
configs["v2_tc"]["stage2"]["num_warps"] = v2_tc_stage2_best_config.num_warps
ms["v2_tc"] = v2_tc_ms
print(f"v2_tc best configs is {configs['v2_tc']}")
print("print mla decode attention v2 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v2_tc_ms)
o2 = torch.zeros_like(o)
v1_tc_stage1_best_config, v1_tc_stage2_best_config = decode_attentionv1_fwd(
q,
k_buffer,
v_buffer,
o2,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o2, atol=1e-2, rtol=1e-2)
v1_tc_ms, v1_tc_min_ms, v1_tc_max_ms = triton.testing.do_bench(lambda:
decode_attentionv1_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
), quantiles=quantiles)
for key, value in v1_tc_stage1_best_config.kwargs.items():
configs["v1_2stages_tc"]["stage1"][key] = value
configs["v1_2stages_tc"]["stage1"]["num_stages"] = v1_tc_stage1_best_config.num_stages
configs["v1_2stages_tc"]["stage1"]["num_warps"] = v1_tc_stage1_best_config.num_warps
configs["v1_2stages_tc"]["stage1"]["num_ldmatrixes"] = v1_tc_stage1_best_config.num_ldmatrixes
for key, value in v1_tc_stage2_best_config.kwargs.items():
configs["v1_2stages_tc"]["stage2"][key] = value
configs["v1_2stages_tc"]["stage2"]["num_stages"] = v1_tc_stage2_best_config.num_stages
configs["v1_2stages_tc"]["stage2"]["num_warps"] = v1_tc_stage2_best_config.num_warps
configs["v1_2stages_tc"]["stage2"]["num_ldmatrixes"] = v1_tc_stage1_best_config.num_ldmatrixes
ms["v1_2stages_tc"] = v1_tc_ms
min_key, min_ms = min(ms.items(), key=lambda x: x[1])
final_best_config["kernel_kind"] = min_key
final_best_config["best_config"] = configs[min_key]
final_best_config["best_us"] = min_ms * 1000
print(f"v1_2stages_tc best configs is {configs['v1_2stages_tc']}")
print("print mla decode attention v1 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :",[B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE], v1_tc_ms)
print(f"Tuned_decode_attention choose {min_key} kernel, min cost {min_ms} ms, best config of {min_key} kernel is {configs[min_key]}")
assert torch.allclose(o, o2, atol=1e-2, rtol=1e-2)
#**************save config**************#
batch = b_req_idx.shape[0]
mean_seq_len = int((b_seq_len.sum() / max(1, batch)).item())
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
# return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
file_name = f"QH={H_Q}_KVH={H_KV}_QKD={D_QK}_VD={D_V}_fp16_K100AI.json"
elif "BW" in device_name:
# return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
file_name = f"QH={H_Q}_KVH={H_KV}_QKD={D_QK}_VD={D_V}_fp16_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
if os.path.exists(file_name):
with open(file_name, 'r') as file:
config_info = json.load(file)
else:
config_info = {}
# If the current batch is not yet in config_info, initialize it as an empty dict
# if f"{batch}" not in config_info:
#     config_info[f"{batch}"] = {}
# Add the new mean_seq_len entry for the current batch
# config_info[f"{batch}"][f"{mean_seq_len}"] = final_best_config
config_info[f"{mean_seq_len}"] = final_best_config
# Save the best config
with open(file_name, 'w') as file:
json.dump(config_info, file, indent=1)
#**************save config**************#
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
# Changes:
# - Add support for page size >= 1.
# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""
import os
import logging
import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform
is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
logger = logging.getLogger(__name__)
# TODO: Remove this when triton>=3.2.0. This issue will not affect performance
# and accuracy.
logger.warning(
"The following error message 'operation scheduled before its operands' "
"can be ignored.")
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel_stage1(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
split_kv_id = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = cur_batch
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
cur_batch_seq_len)
e_max = -float("inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
cur_kv_head * stride_buf_kh + offs_d[None, :])
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
other=0.0,
)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
cur_kv_head * stride_buf_vh + offs_dv[None, :])
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
acc *= re_scale
acc += tl.sum(p[:, None] * v, 0)
e_sum = e_sum * re_scale + tl.sum(p, 0)
e_max = n_e_max
offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
split_kv_id * stride_mid_os + offs_dv)
tl.store(
Att_Out + offs_mid_o,
acc / e_sum,
mask=(mask_dv),
)
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
split_kv_id * stride_mid_os + Lv)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
)
def _decode_att_m_fwd(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
BLOCK = 64
NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
batch, head_num = q.shape[0], q.shape[1]
grid = (batch, head_num, NUM_KV_SPLITS)
kv_group_num = q.shape[1] // k_buffer.shape[-2]
num_warps = 4 if kv_group_num == 1 else 2
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv)
_fwd_kernel_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=num_warps,
Lk=Lk,
Lv=Lv,
)
@triton.jit
def _fwd_kernel_stage2(
Mid_O,
o,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
cur_batch_seq_len)
if split_kv_end > split_kv_start:
tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
mask=mask_d,
other=0.0)
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
def _decode_softmax_reducev_fwd(
logits,
q,
o,
v_buffer,
b_seq_len,
num_kv_splits,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
NUM_KV_SPLITS = num_kv_splits
extra_kargs = {}
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
grid = (batch, head_num)
_fwd_kernel_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
num_warps=4,
**extra_kargs,
)
def decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_att_m_fwd(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
num_kv_splits)
# opt
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
)
@triton.jit
def _decode_v1_kernel_stage1_use_tc(
Q,
K_Buffer,
sm_scale,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
att_stride_h,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
SPLIT_K: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_k_id = tl.program_id(2)
#reduce_dtype = Att_Out.dtype.element_ty
if BLOCK_H < kv_group_num:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
# cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(
Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
)#.to(reduce_dtype)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0) #.to(reduce_dtype)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
split_k_start = kv_len_per_split * split_k_id
split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
for start_n in range(split_k_start, split_k_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
offs_n // PAGE_SIZE,
mask=offs_n < split_k_end,
other=0,
)
k_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
k_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
other=0.0,
) #.to(reduce_dtype)
qk = tl.dot(q, k)
if BLOCK_DPE > 0:
offs_buf_kpe = (
k_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=offs_n[None, :] < split_k_end,
other=0.0,
) #.to(reduce_dtype)
qk += tl.dot(qpe, kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
offs_o = cur_head[:, None] * att_stride_h + (
cur_batch_in_all_start_index + offs_n[None, :]
)
tl.store(
Att_Out + offs_o,
qk,
mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
)
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"]
)
@triton.jit
def _decode_v1_kernel_stage2_use_tc(
logits,
V_Buffer,
Out,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
stride_logic_h,
stride_buf_vbs,
stride_buf_vh,
stride_obs,
stride_oh,
stride_req_to_token_b,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_H: tl.constexpr,
PAGE_SIZE: tl.constexpr,
Lv: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1)
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
mask_h = mask_h & (cur_head < q_head_num)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
cur_batch_req_idx = cur_batch #tl.load(B_req_idx + cur_batch)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
v_ptrs = V_Buffer + offs_buf_v
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
v_page_number = tl.load(
Req_to_tokens
+ cur_batch_req_idx * stride_req_to_token_b
+ (start_n + offs_n) // PAGE_SIZE,
mask=(start_n + offs_n) < cur_batch_seq_len,
other=0,
)
v_loc = v_page_number * PAGE_SIZE + (start_n + offs_n) % PAGE_SIZE
offs_qk = cur_head[:, None] * stride_logic_h + (
cur_batch_start_loc + start_n + offs_n[None, :]
)
qk = tl.load(
logits + offs_qk,
mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
other=float("-inf"),
) #[head, block_n]
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
e_sum = e_sum * old_scale + tl.sum(p, 1)
v = tl.load(
v_ptrs + v_loc[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
) #[block_n,head_dim]
p = p.to(v.dtype)
acc = acc * old_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
acc = acc / e_sum[:, None]
off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
def _decode_v1_stage1_use_tc(
q,
k_buffer,
att_out,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
sm_scale,
page_size,
num_kv_splits,
logit_cap,
):
Lk = k_buffer.shape[-1]
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0
# batch, head_num = B_req_idx.shape[0], q.shape[1]
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
SPLIT_K = num_kv_splits
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
SPLIT_K,
)
_decode_v1_kernel_stage1_use_tc[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
att_out.stride(0),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_H=BLOCK_H,
SPLIT_K=SPLIT_K,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
Lk=Lk,
kpack=2,
)
return _decode_v1_kernel_stage1_use_tc.best_config
def _decode_v1_stage2_use_tc(
logits,
v_buffer,
o,
req_to_tokens,
#b_req_idx,
b_start_loc,
b_seq_len,
page_size,
):
batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[-2]
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
Lv = v_buffer.shape[-1]
BLOCK_DMODEL = triton.next_power_of_2(Lv)
_decode_v1_kernel_stage2_use_tc[grid](
logits,
v_buffer,
o,
req_to_tokens,
#b_req_idx,
b_start_loc,
b_seq_len,
logits.stride(0),
v_buffer.stride(-3),
v_buffer.stride(-2),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_H=BLOCK_H,
PAGE_SIZE=page_size,
Lv=Lv,
)
return _decode_v1_kernel_stage2_use_tc.best_config
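# v1 decode path for GQA/MQA/MLA: stage 1 writes the attention logits for
# every (head, KV token) pair into attn_logits; stage 2 streams over those
# logits with an online softmax and reduces them against the paged V cache.
# Both launchers return the Triton autotuner's best config.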
def decode_attention_v1(
q,
k_buffer,
v_buffer,
o,
req_to_token,
#b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
# GQA/MQA/MLA
_decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
q,
k_buffer,
attn_logits,
req_to_token,
#b_req_idx,
b_start_loc,
b_seq_len,
sm_scale,
page_size,
num_kv_splits,
logit_cap,
)
_decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
attn_logits,
v_buffer,
o,
req_to_token,
#b_req_idx,
b_start_loc,
b_seq_len,
page_size,
)
return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=4, num_stages=1),
],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
)
@triton.jit
def _decode_v2_kernel_stage1_use_tc(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DIM: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_kv_id = tl.program_id(2)
if BLOCK_H < kv_group_num:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
# offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
# mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
# cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_req_idx = cur_batch
# offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
# q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
)
qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
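    # Flash-decoding style split: this program only covers the KV range
    # [split_kv_start, split_kv_end); every split writes a partial output
    # plus its log-sum-exp so stage 2 can merge the splits afterwards.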
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
NUM_DIM_SPLIT = tl.cdiv(BLOCK_DMODEL, BLOCK_DIM)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32)
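            # The head dimension is processed in BLOCK_DIM-wide chunks to keep
            # the Q/K tiles small; the partial Q.K products accumulate into qk.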
for i in range(0, NUM_DIM_SPLIT):
offs_d = tl.arange(0, BLOCK_DIM) + i * BLOCK_DIM
mask_d = offs_d < Lk
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None,:]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
offs_buf_k = kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None]
k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0)
qk += tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0,
)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
)
offs_buf_v = (
kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = (
cur_batch * stride_mid_ob
+ cur_head[:, None] * stride_mid_oh
+ split_kv_id * stride_mid_os
+ offs_dv[None, :]
)
tl.store(
Att_Out + offs_mid_o,
acc / e_sum[:, None],
mask=(mask_h[:, None]) & (mask_dv[None, :]),
)
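        # The extra slot at offset Lv keeps the per-split log-sum-exp
        # (e_max + log(e_sum)); stage 2 reads it to merge the splits.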
offs_mid_o_1 = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ split_kv_id * stride_mid_os
+ Lv
)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
mask=mask_h,
)
def _decode_v2_stage1_use_tc(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
# batch, head_num = B_req_idx.shape[0], q.shape[1]
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_H = 16
NUM_KV_SPLITS = num_kv_splits
grid = lambda META: (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
NUM_KV_SPLITS,
)
_decode_v2_kernel_stage1_use_tc[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
v_buffer.stride(-3),
v_buffer.stride(-2),
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
Lk=Lk,
Lv=Lv,
kpack=2,
)
return _decode_v2_kernel_stage1_use_tc.best_config
@triton.autotune(
configs=[
        triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=2, num_stages=1),
triton.Config({}, num_warps=4, num_stages=1),
triton.Config({}, num_warps=8, num_stages=1),
],
key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"]
)
@triton.jit
def _decode_v2_kernel_stage2(
Mid_O,
O,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
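    # Merge the per-split partial results: each split contributed a partial
    # weighted sum (at offs_v) and its log-sum-exp (at offs_logic); combine
    # them with a numerically stable running max and denominator.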
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
if split_kv_end > split_kv_start:
tv = tl.load(
Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
)
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
def _decode_v2_stage2_use_tc(
logits,
q,
o,
v_buffer,
b_seq_len,
num_kv_splits,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
NUM_KV_SPLITS = num_kv_splits
grid = (batch, head_num)
_decode_v2_kernel_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
)
return _decode_v2_kernel_stage2.best_config
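# v2 decode path (flash-decoding): stage 1 splits each request's KV sequence
# into num_kv_splits chunks and writes, per (batch, head, split), a partial
# output of length Lv plus one log-sum-exp scalar into attn_logits; stage 2
# reduces across the splits into o. Both launchers return the autotuner's
# best config.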
def decode_attention_v2(
q,
k_buffer,
v_buffer,
o,
req_to_token,
# b_req_idx,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
# b_req_idx,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
def decode_attentionv2_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size=1,
logit_cap=0.0,
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
num_b = min(kv_group_num, 16)
grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0]
L = req_to_token.shape[1]*page_size
if grid_num * num_kv_splits < 128:
num_kv_splits = (127 + grid_num) // grid_num
attn_logits_v1 = torch.empty(
(q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1),
dtype=torch.float32,
device="cuda",
)
    # Default to None so the return below is also defined on the MHA path.
    v2_tc_stage1_best_config = v2_tc_stage2_best_config = None
    if kv_group_num == 1:
# MHA
decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits_v1,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
else:
# GQA/MQA/MLA
v2_tc_stage1_best_config, v2_tc_stage2_best_config = decode_attention_v2(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
return v2_tc_stage1_best_config, v2_tc_stage2_best_config
def decode_attentionv1_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size=1,
logit_cap=0.0,
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
    total_tokens = req_to_token.shape[0] * req_to_token.shape[1] * page_size
    b_start_loc = torch.arange(
        0, total_tokens, total_tokens // q.shape[0], device="cuda"
    ).to(torch.int32)
    # Default to None so the return below is also defined on the MHA path.
    v1_tc_stage1_best_config = v1_tc_stage2_best_config = None
    if kv_group_num == 1:
# MHA
decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
else:
attn_logits_v1 = torch.empty(
(q.shape[1],req_to_token.shape[0]*req_to_token.shape[1]*page_size),
dtype=torch.float32,
device="cuda")
v1_tc_stage1_best_config, v1_tc_stage2_best_config = decode_attention_v1(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_start_loc,
b_seq_len,
attn_logits_v1,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
return v1_tc_stage1_best_config, v1_tc_stage2_best_config
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.
For most models, the prompt format should follow the corresponding examples
on the Hugging Face model repository.
"""
import os
from dataclasses import asdict
from typing import NamedTuple, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = {
0: "What is 1+1?",
1: "What is recited in the audio?",
2: "What sport and what nursery rhyme are referenced?",
}
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompt: str
stop_token_ids: Optional[list[int]] = None
lora_requests: Optional[list[LoRARequest]] = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    # NOTE - the settings in this example are somewhat different from what is
    # optimal for Granite Speech, and it is generally recommended to use beam
    # search. Check the model README for suggested settings.
# https://huggingface.co/ibm-granite/granite-speech-3.3-8b
model_name = "ibm-granite/granite-speech-3.3-8b"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=2048,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=64,
limit_mm_per_prompt={"audio": audio_count},
)
# The model has an audio-specific lora directly in its model dir;
# it should be enabled whenever you pass audio inputs to the model.
speech_lora_path = model_name
audio_placeholder = "<|audio|>" * audio_count
prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
return ModelRequestData(
engine_args=engine_args,
prompt=prompts,
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
)
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
)
stop_tokens = ["<|im_end|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
audio_placeholder = "(<audio>./</audio>)" * audio_count
audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501
messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
chat_template=audio_chat_template,
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
stop_token_ids=stop_token_ids,
)
# Phi-4-multimodal-instruct
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process audio inputs.
"""
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
speech_lora_path = os.path.join(model_path, "speech-lora")
placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
max_model_len=12800,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompts,
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
)
# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join(
[
f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
for idx in range(audio_count)
]
)
prompt = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int):
model_name = "Qwen/Qwen2.5-Omni-7B"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join(
["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)]
)
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    assert audio_count == 1, "Whisper only supports a single audio input per prompt"
model_name = "openai/whisper-large-v3-turbo"
prompt = "<|startoftranscript|>"
engine_args = EngineArgs(
model=model_name,
max_model_len=448,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
model_example_map = {
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,
"ultravox": run_ultravox,
"whisper": run_whisper,
}
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using vLLM for offline inference with "
"audio language models"
)
parser.add_argument(
"--model-type",
"-m",
type=str,
default="ultravox",
choices=model_example_map.keys(),
help='Huggingface "model_type".',
)
parser.add_argument(
"--num-prompts", type=int, default=1, help="Number of prompts to run."
)
parser.add_argument(
"--num-audios",
type=int,
default=1,
choices=[0, 1, 2],
help="Number of audio items per prompt.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args()
def main(args):
model = args.model_type
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
audio_count = args.num_audios
req_data = model_example_map[model](
question_per_audio_count[audio_count], audio_count
)
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
    # We set temperature to 0.2 so that outputs can differ even when all
    # prompts are identical when running batch inference.
sampling_params = SamplingParams(
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)
mm_data = {}
if audio_count > 0:
mm_data = {
"audio": [
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
]
}
assert args.num_prompts > 0
inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
if args.num_prompts > 1:
# Batch inference
inputs = [inputs] * args.num_prompts
# Add LoRA request if applicable
lora_request = (
req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
)
outputs = llm.generate(
inputs,
sampling_params=sampling_params,
lora_request=lora_request,
)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
args = parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstration script for Automatic Prefix Caching (APC) in vLLM.
Automatic Prefix Caching (APC) allows the vLLM engine to reuse cached
KV (key-value) pairs from previous prompts if a new query shares the same
prefix. This reduces redundant computation and improves inference speed.
To enable APC, set `enable_prefix_caching=True` when initializing the
vLLM engine.
This script uses a long Markdown table as the shared prompt prefix and
compares the generation time for two queries that share the same prefix
but ask different questions.
Run:
python examples/offline_inference/automatic_prefix_caching.py
"""
import time
from vllm import LLM, SamplingParams
# ruff: noqa: E501
# A prompt containing a large markdown table. The table is randomly generated by GPT-4.
LONG_PROMPT = (
"You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n"
+ """
| ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON |
| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK |
| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW |
| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ |
| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE |
| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY |
| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC |
| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK |
| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC|
| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ |
| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE |
| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA |
| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB |
| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK |
| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD |
| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ |
| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE |
| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA |
| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON |
| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK |
| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA |
| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ|
| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE |
| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO |
| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC |
| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK |
| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA |
| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
"""
)
def get_generation_time(llm, sampling_params, prompts):
# time the generation
start_time = time.time()
output = llm.generate(prompts, sampling_params=sampling_params)
end_time = time.time()
# print the output and generation time
print("-" * 30)
print(f"Output: {output[0].outputs[0].text}")
print(f"Generation time: {end_time - start_time} seconds.")
print("-" * 30)
def main():
# set enable_prefix_caching=True to enable APC
llm = LLM(model="lmsys/longchat-13b-16k", enable_prefix_caching=True)
sampling_params = SamplingParams(temperature=0, max_tokens=100)
# Querying the age of John Doe
get_generation_time(
llm,
sampling_params,
LONG_PROMPT
+ "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
)
# Querying the age of Zack Blue
    # This query will be faster since vLLM avoids computing the KV cache of LONG_PROMPT again.
get_generation_time(
llm,
sampling_params,
LONG_PROMPT
+ "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
)
if __name__ == "__main__":
main()
# Basic
The `LLM` class provides the primary Python interface for doing offline inference, that is, interacting with a model without using a separate model inference server.
## Usage
The first script in this example shows the most basic usage of vLLM. If you are new to Python and vLLM, you should start here.
```bash
python examples/offline_inference/basic/basic.py
```
The rest of the scripts include an [argument parser](https://docs.python.org/3/library/argparse.html), which you can use to pass any arguments that are compatible with [`LLM`](https://docs.vllm.ai/en/latest/api/offline_inference/llm.html). Try running the script with `--help` for a list of all available arguments.
```bash
python examples/offline_inference/basic/classify.py
```
```bash
python examples/offline_inference/basic/embed.py
```
```bash
python examples/offline_inference/basic/score.py
```
The chat and generate scripts also accept the [sampling parameters](https://docs.vllm.ai/en/latest/api/inference_params.html#sampling-parameters): `max_tokens`, `temperature`, `top_p` and `top_k`.
```bash
python examples/offline_inference/basic/chat.py
```
```bash
python examples/offline_inference/basic/generate.py
```
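For example, to cap generation at 32 tokens with mild sampling (the values here are only illustrative):
```bash
python examples/offline_inference/basic/generate.py --max-tokens 32 --temperature 0.8 --top-p 0.95
```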
## Features
In the scripts that support passing arguments, you can experiment with the following features.
### Default generation config
The `--generation-config` argument specifies where the generation config will be loaded from when calling `LLM.get_default_sampling_params()`. If set to `auto`, the generation config is loaded from the model path. If set to a folder path, it is loaded from that folder. If it is not provided, vLLM defaults are used.
> If `max_new_tokens` is specified in the generation config, it sets a server-wide limit on the number of output tokens for all requests.
Try it yourself with the following argument:
```bash
--generation-config auto
```
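The same lookup is available from Python. Below is a minimal sketch, assuming the model ships a `generation_config.json` and that your vLLM version accepts `generation_config` as a keyword of `LLM`:
```python
from vllm import LLM

# "auto" mirrors the --generation-config auto CLI flag: defaults are read
# from the generation_config.json shipped with the model (assumed to be
# supported as an LLM keyword in your vLLM version).
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", generation_config="auto")

# Sampling params pre-populated from the model's generation config.
sampling_params = llm.get_default_sampling_params()
print(sampling_params)
```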
### Quantization
#### AQLM
vLLM supports models that are quantized using AQLM.
Try one yourself by passing one of the following models to the `--model` argument:
- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf`
- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf`
- `ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf`
- `ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf`
- `BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf`
> Some of these models are likely to be too large for a single GPU. You can split them across multiple GPUs by setting `--tensor-parallel-size` to the number of required GPUs.
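For example, a possible invocation that splits the Mixtral checkpoint across two GPUs (the GPU count is illustrative):
```bash
--model ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf --tensor-parallel-size 2
```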
#### GGUF
vLLM supports models that are quantized using GGUF.
Try one yourself by downloading a quantized GGUF model and using the following arguments:
```python
from huggingface_hub import hf_hub_download
repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
print(hf_hub_download(repo_id, filename=filename))
```
```bash
--model {local-path-printed-above} --tokenizer microsoft/Phi-3-medium-4k-instruct
```
### CPU offload
The `--cpu-offload-gb` argument can be seen as a virtual way to increase the GPU memory size. For example, if you have one 24 GB GPU and set this to 10, you can virtually think of it as a 34 GB GPU. That lets you load a 13B model with BF16 weights, which requires at least 26 GB of GPU memory. Note that this requires a fast CPU-GPU interconnect, as part of the model is loaded from CPU memory to GPU memory on the fly in each model forward pass.
Try it yourself with the following arguments:
```bash
--model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10
```
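The same setting can be used from Python. Below is a minimal sketch, assuming `cpu_offload_gb` is accepted as a keyword of `LLM` in your vLLM version:
```python
from vllm import LLM

# Offload roughly 10 GB of weights to CPU memory; they are streamed back to
# the GPU on the fly during each forward pass, so a fast CPU-GPU
# interconnect is assumed.
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```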
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"Hello, my name is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
def main():
# Create an LLM.
llm = LLM(model="/mnt/data/llm-models/qwen3/Qwen3-8B",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True, block_size=64, enable_prefix_caching=False)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
# Add example params
parser.add_argument("--chat-template-path", type=str)
return parser
def main(args: dict):
# Pop arguments not used by LLM
max_tokens = args.pop("max_tokens")
temperature = args.pop("temperature")
top_p = args.pop("top_p")
top_k = args.pop("top_k")
chat_template_path = args.pop("chat_template_path")
# Create an LLM
llm = LLM(**args)
# Create sampling params object
sampling_params = llm.get_default_sampling_params()
if max_tokens is not None:
sampling_params.max_tokens = max_tokens
if temperature is not None:
sampling_params.temperature = temperature
if top_p is not None:
sampling_params.top_p = top_p
if top_k is not None:
sampling_params.top_k = top_k
def print_outputs(outputs):
print("\nGenerated Outputs:\n" + "-" * 80)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n")
print(f"Generated text: {generated_text!r}")
print("-" * 80)
print("=" * 80)
# In this script, we demonstrate how to pass input to the chat method:
conversation = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{
"role": "user",
"content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
print_outputs(outputs)
# You can run batch inference with llm.chat API
conversations = [conversation for _ in range(10)]
    # We turn on the tqdm progress bar to verify that batch inference is indeed running.
outputs = llm.chat(conversations, sampling_params, use_tqdm=True)
print_outputs(outputs)
# A chat template can be optionally supplied.
# If not, the model will use its default chat template.
if chat_template_path is not None:
with open(chat_template_path) as f:
chat_template = f.read()
outputs = llm.chat(
conversations,
sampling_params,
use_tqdm=False,
chat_template=chat_template,
)
if __name__ == "__main__":
parser = create_parser()
args: dict = vars(parser.parse_args())
main(args)